From 86e7330fa21cbbd1a382a7d6ff74908bcae2c728 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Thu, 25 Apr 2024 18:55:49 +0800 Subject: [PATCH] test: refactor vdb tests by visitor design pattern (#3838) --- .../vdb/milvus/test_milvus.py | 47 ++++++++----------- .../vdb/qdrant/test_qdrant.py | 45 ++++++------------ .../vdb/test_vector_store.py | 32 +++++++++++++ .../vdb/weaviate/test_weaviate.py | 47 ++++++------------- 4 files changed, 80 insertions(+), 91 deletions(-) diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 22ed73987..e829a8e4d 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -1,38 +1,29 @@ -import uuid - from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector -from models.dataset import Dataset from tests.integration_tests.vdb.test_vector_store import ( - get_sample_document, - get_sample_embedding, - get_sample_query_vector, + AbstractTestVector, + get_sample_text, setup_mock_redis, ) -def test_milvus_vector(setup_mock_redis) -> None: - dataset_id = str(uuid.uuid4()) - vector = MilvusVector( - collection_name=Dataset.gen_collection_name_by_id(dataset_id), - config=MilvusConfig( - host='localhost', - port=19530, - user='root', - password='Milvus', +class TestMilvusVector(AbstractTestVector): + def __init__(self): + super().__init__() + self.vector = MilvusVector( + collection_name=self.collection_name, + config=MilvusConfig( + host='localhost', + port=19530, + user='root', + password='Milvus', + ) ) - ) - # create vector - vector.create( - texts=[get_sample_document(dataset_id)], - embeddings=[get_sample_embedding()], - ) + def search_by_full_text(self): + # milvus dos not support full text searching yet in < 2.3.x + hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) + assert len(hits_by_full_text) == 0 - # search by vector - hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector()) - assert len(hits_by_vector) >= 1 - # milvus dos not support full text searching yet in < 2.3.x - - # delete vector - vector.delete() +def test_milvus_vector(setup_mock_redis): + TestMilvusVector().run_all_test() diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 33e9d55dc..0ef3a253b 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,40 +1,23 @@ -import uuid - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector -from models.dataset import Dataset from tests.integration_tests.vdb.test_vector_store import ( - get_sample_document, - get_sample_embedding, - get_sample_query_vector, - get_sample_text, + AbstractTestVector, setup_mock_redis, ) -def test_qdrant_vector(setup_mock_redis)-> None: - dataset_id = str(uuid.uuid4()) - vector = QdrantVector( - collection_name=Dataset.gen_collection_name_by_id(dataset_id), - group_id=dataset_id, - config=QdrantConfig( - endpoint='http://localhost:6333', - api_key='difyai123456', +class TestQdrantVector(AbstractTestVector): + def __init__(self): + super().__init__() + self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.vector = QdrantVector( + collection_name=self.collection_name, + group_id=self.dataset_id, + config=QdrantConfig( + endpoint='http://localhost:6333', + api_key='difyai123456', + ) ) - ) - # create vector - vector.create( - texts=[get_sample_document(dataset_id)], - embeddings=[get_sample_embedding()], - ) - # search by vector - hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector()) - assert len(hits_by_vector) >= 1 - - # search by full text - hits_by_full_text = vector.search_by_full_text(query=get_sample_text()) - assert len(hits_by_full_text) >= 1 - - # delete vector - vector.delete() +def test_qdrant_vector(setup_mock_redis): + TestQdrantVector().run_all_test() diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index 536f3c735..ab770be22 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -1,9 +1,11 @@ +import uuid from unittest.mock import MagicMock import pytest from core.rag.models.document import Document from extensions import ext_redis +from models.dataset import Dataset def get_sample_text() -> str: @@ -44,3 +46,33 @@ def setup_mock_redis() -> None: mock_redis_lock.__enter__ = MagicMock() mock_redis_lock.__exit__ = MagicMock() ext_redis.redis_client.lock = mock_redis_lock + + +class AbstractTestVector: + def __init__(self): + self.vector = None + self.dataset_id = str(uuid.uuid4()) + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + + def create_vector(self) -> None: + self.vector.create( + texts=[get_sample_document(self.dataset_id)], + embeddings=[get_sample_embedding()], + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector()) + assert len(hits_by_vector) >= 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) + assert len(hits_by_full_text) >= 1 + + def delete_vector(self): + self.vector.delete() + + def run_all_test(self): + self.create_vector() + self.search_by_vector() + self.search_by_full_text() + self.delete_vector() diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 1a07d8692..338e33145 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -1,41 +1,24 @@ -import uuid - from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from models.dataset import Dataset from tests.integration_tests.vdb.test_vector_store import ( - get_sample_document, - get_sample_embedding, - get_sample_query_vector, - get_sample_text, + AbstractTestVector, setup_mock_redis, ) -def test_weaviate_vector(setup_mock_redis) -> None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] - dataset_id = str(uuid.uuid4()) - vector = WeaviateVector( - collection_name=Dataset.gen_collection_name_by_id(dataset_id), - config=WeaviateConfig( - endpoint='http://localhost:8080', - api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih', - ), - attributes=attributes - ) +class TestWeaviateVector(AbstractTestVector): + def __init__(self): + super().__init__() + self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.vector = WeaviateVector( + collection_name=self.collection_name, + config=WeaviateConfig( + endpoint='http://localhost:8080', + api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih', + ), + attributes=self.attributes + ) - # create vector - vector.create( - texts=[get_sample_document(dataset_id)], - embeddings=[get_sample_embedding()], - ) - # search by vector - hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector()) - assert len(hits_by_vector) >= 1 - - # search by full text - hits_by_full_text = vector.search_by_full_text(query=get_sample_text()) - assert len(hits_by_full_text) >= 1 - - # delete vector - vector.delete() +def test_weaviate_vector(setup_mock_redis): + TestWeaviateVector().run_all_test()