test: add bulk import testcases for full text search (#37197)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
zhuwenxing 2024-11-27 19:32:42 +08:00 committed by GitHub
parent 8188e1472d
commit e5775a71af
2 changed files with 170 additions and 1 deletion


@@ -27,6 +27,7 @@ class DataField:
vec_field = "vectors"
float_vec_field = "float32_vectors"
sparse_vec_field = "sparse_vectors"
bm25_sparse_vec_field = "bm25_sparse_vectors"
image_float_vec_field = "image_float_vec_field"
text_float_vec_field = "text_float_vec_field"
binary_vec_field = "binary_vec_field"


@@ -2,7 +2,7 @@ import logging
import random
import time
import pytest
from pymilvus import DataType
from pymilvus import DataType, Function, FunctionType
from pymilvus.bulk_writer import RemoteBulkWriter, BulkFileType
import numpy as np
from pathlib import Path
@@ -1602,6 +1602,174 @@ class TestBulkInsert(TestcaseBaseBulkInsert):
assert "name" in fields_from_search
assert "address" in fields_from_search
@pytest.mark.tags(CaseLabel.L3)
@pytest.mark.parametrize("auto_id", [True])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("entities", [1000])
@pytest.mark.parametrize("enable_dynamic_field", [True])
@pytest.mark.parametrize("sparse_format", ["doc"])
@pytest.mark.parametrize("file_format", ["parquet", "json"])
def test_with_all_field_and_bm25_function_with_bulk_writer(self, auto_id, dim, entities, enable_dynamic_field, sparse_format, file_format):
"""
target: test bulk insert with all field types and a BM25 function
method: create a collection with all field types and a BM25 function, then import data with the bulk writer
expected: verify the data is imported correctly and is searchable, including via full text search
"""
self._connect()
fields = [
cf.gen_int64_field(name=df.pk_field, is_primary=True, auto_id=auto_id),
cf.gen_int64_field(name=df.int_field),
cf.gen_float_field(name=df.float_field),
cf.gen_string_field(name=df.string_field),
cf.gen_string_field(name=df.text_field, enable_analyzer=True, enable_match=True),
cf.gen_json_field(name=df.json_field),
cf.gen_array_field(name=df.array_int_field, element_type=DataType.INT64),
cf.gen_array_field(name=df.array_float_field, element_type=DataType.FLOAT),
cf.gen_array_field(name=df.array_string_field, element_type=DataType.VARCHAR, max_length=100),
cf.gen_array_field(name=df.array_bool_field, element_type=DataType.BOOL),
cf.gen_float_vec_field(name=df.float_vec_field, dim=dim),
cf.gen_sparse_vec_field(name=df.sparse_vec_field),
cf.gen_sparse_vec_field(name=df.bm25_sparse_vec_field),
]
c_name = cf.gen_unique_str("bulk_insert")
schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id, enable_dynamic_field=enable_dynamic_field)
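# the BM25 function below derives the sparse vector field df.bm25_sparse_vec_field
# from the analyzer-enabled text field at insert/import time, so rows never supply
# values for that field directly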
bm25_function = Function(
name="text_bm25_emb",
function_type=FunctionType.BM25,
input_field_names=[df.text_field],
output_field_names=[df.bm25_sparse_vec_field],
params={},
)
schema.add_function(bm25_function)
self.collection_wrap.init_collection(c_name, schema=schema)
documents = []
if file_format == "parquet":
ff = BulkFileType.PARQUET
elif file_format == "json":
ff = BulkFileType.JSON
else:
raise Exception(f"unsupported file format: {file_format}")
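# RemoteBulkWriter stages the rows as parquet/json files in the MinIO bucket;
# the resulting file batches are later handed to do_bulk_insert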
with RemoteBulkWriter(
schema=schema,
remote_path="bulk_data",
connect_param=RemoteBulkWriter.ConnectParam(
bucket_name=self.bucket_name,
endpoint=self.minio_endpoint,
access_key="minioadmin",
secret_key="minioadmin",
),
file_type=ff,
) as remote_writer:
json_value = [
# 1,
# 1.0,
# "1",
# [1, 2, 3],
# ["1", "2", "3"],
# [1, 2, "3"],
{"key": "value"},
]
for i in range(entities):
row = {
df.pk_field: i,
df.int_field: 1,
df.float_field: 1.0,
df.string_field: "string",
df.text_field: fake.text(),
df.json_field: json_value[i%len(json_value)],
df.array_int_field: [1, 2],
df.array_float_field: [1.0, 2.0],
df.array_string_field: ["string1", "string2"],
df.array_bool_field: [True, False],
df.float_vec_field: cf.gen_vectors(1, dim)[0],
df.sparse_vec_field: cf.gen_sparse_vectors(1, dim, sparse_format=sparse_format)[0]
}
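# note: df.bm25_sparse_vec_field is deliberately absent from the row, since the
# BM25 function computes it from df.text_field during import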
if auto_id:
row.pop(df.pk_field)
if enable_dynamic_field:
row["name"] = fake.name()
row["address"] = fake.address()
documents.append(row[df.text_field])
remote_writer.append_row(row)
remote_writer.commit()
files = remote_writer.batch_files
# import data
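# each element of batch_files is one group of files and is imported as a separate bulk insert task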
for f in files:
t0 = time.time()
task_id, _ = self.utility_wrap.do_bulk_insert(
collection_name=c_name, files=f
)
logging.info(f"bulk insert task ids:{task_id}")
success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed(
task_ids=[task_id], timeout=300
)
tt = time.time() - t0
log.info(f"bulk insert state:{success} in {tt} with states:{states}")
assert success
num_entities = self.collection_wrap.num_entities
log.info(f" collection entities: {num_entities}")
assert num_entities == entities
# verify imported data is available for search
index_params = ct.default_index
float_vec_fields = [f.name for f in fields if "vec" in f.name and "float" in f.name]
sparse_vec_fields = [f.name for f in fields if "vec" in f.name and "sparse" in f.name and "bm25" not in f.name]
bm25_sparse_vec_fields = [f.name for f in fields if "vec" in f.name and "sparse" in f.name and "bm25" in f.name]
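# the BM25-generated sparse field needs the text sparse index (BM25 metric),
# while the plain sparse field uses the regular sparse inverted index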
for f in float_vec_fields:
self.collection_wrap.create_index(
field_name=f, index_params=index_params
)
for f in sparse_vec_fields:
self.collection_wrap.create_index(
field_name=f, index_params=ct.default_sparse_inverted_index
)
for f in bm25_sparse_vec_fields:
self.collection_wrap.create_index(
field_name=f, index_params=ct.default_text_sparse_inverted_index
)
self.collection_wrap.load()
log.info(f"wait for load finished and be ready for search")
time.sleep(2)
# log.info(f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}")
search_data = cf.gen_vectors(1, dim)
search_params = ct.default_search_params
res, _ = self.collection_wrap.search(
search_data,
df.float_vec_field,
param=search_params,
limit=1,
output_fields=["*"],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1, "limit": 1},
)
for hit in res:
for r in hit:
fields_from_search = r.fields.keys()
for f in fields:
if f.name == df.bm25_sparse_vec_field:
continue
assert f.name in fields_from_search
if enable_dynamic_field:
assert "name" in fields_from_search
assert "address" in fields_from_search
# verify full text search
word_freq = cf.analyze_documents(documents)
token = word_freq.most_common(1)[0][0]
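# query with the most frequent token of the imported corpus so the BM25 search is guaranteed to return a hit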
search_data = [f" {token} " + fake.text()]
search_params = ct.default_text_sparse_search_params
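# full text search takes raw query text; the server embeds it with the same BM25 function before searching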
res, _ = self.collection_wrap.search(
search_data,
df.bm25_sparse_vec_field,
param=search_params,
limit=1,
output_fields=["*"],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1, "limit": 1},
)
@pytest.mark.tags(CaseLabel.L3)
@pytest.mark.parametrize("auto_id", [True, False])
@pytest.mark.parametrize("dim", [128]) # 128