diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index c0027f3c4..f98ab419c 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
@@ -498,7 +498,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
 
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
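
Both hunks above gate the retrieval options on the configured vector store: with VECTOR_STORE set to 'milvus' or now 'relyt', the settings endpoints advertise semantic search only (Relyt has no BM25 full-text search, as relyt_vector.py notes below). A minimal sketch of the resulting response body; the constant name is ours, purely illustrative:

    # Illustrative only -- the dict both endpoints return for 'milvus'/'relyt'.
    RELYT_RETRIEVAL_SETTING = {
        'retrieval_method': [
            'semantic_search'
        ]
    }
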
diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
index cfd97218b..b9e87d0c4 100644
--- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py
+++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
@@ -1,16 +1,23 @@
-import logging
-from typing import Any
+import uuid
+from typing import Any, Optional
 
-from pgvecto_rs.sdk import PGVectoRs, Record
 from pydantic import BaseModel, root_validator
+from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
+from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session
 
+try:
+    from sqlalchemy.orm import declarative_base
+except ImportError:
+    from sqlalchemy.ext.declarative import declarative_base
+
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 
-logger = logging.getLogger(__name__)
+Base = declarative_base()  # type: Any
+
 
 class RelytConfig(BaseModel):
     host: str
@@ -36,16 +43,14 @@
 
 
 class RelytVector(BaseVector):
-    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
+    def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
         super().__init__(collection_name)
+        self.embedding_dimension = 1536
         self._client_config = config
         self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
-        self._client = PGVectoRs(
-            db_url=self._url,
-            collection_name=self._collection_name,
-            dimension=dim
-        )
+        self.client = create_engine(self._url)
         self._fields = []
+        self._group_id = group_id
 
     def get_type(self) -> str:
         return 'relyt'
@@ -54,6 +59,7 @@ class RelytVector(BaseVector):
         index_params = {}
         metadatas = [d.metadata for d in texts]
         self.create_collection(len(embeddings[0]))
+        self.embedding_dimension = len(embeddings[0])
         self.add_texts(texts, embeddings)
 
     def create_collection(self, dimension: int):
@@ -63,21 +69,21 @@
             if redis_client.get(collection_exist_cache_key):
                 return
             index_name = f"{self._collection_name}_embedding_index"
-            with Session(self._client._engine) as session:
-                drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
+            with Session(self.client) as session:
+                drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
                 session.execute(drop_statement)
                 create_statement = sql_text(f"""
-                    CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
-                        id UUID PRIMARY KEY,
-                        text TEXT NOT NULL,
-                        meta JSONB NOT NULL,
+                    CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
+                        id TEXT PRIMARY KEY,
+                        document TEXT NOT NULL,
+                        metadata JSON NOT NULL,
                         embedding vector({dimension}) NOT NULL
                     ) using heap;
                 """)
                 session.execute(create_statement)
                 index_statement = sql_text(f"""
                     CREATE INDEX {index_name}
-                    ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
+                    ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
                     WITH (options = $$
                     optimizing.optimizing_threads = 30
                     segment.max_growing_segment_size = 2000
@@ -92,21 +98,62 @@ class RelytVector(BaseVector):
         redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
-        pks = [str(r.id) for r in records]
-        self._client.insert(records)
-        return pks
+        from pgvecto_rs.sqlalchemy import Vector
+
+        ids = [str(uuid.uuid1()) for _ in documents]
+        metadatas = [d.metadata for d in documents]
+        for metadata in metadatas:
+            metadata['group_id'] = self._group_id
+        texts = [d.page_content for d in documents]
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(len(embeddings[0]))),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        chunks_table_data = []
+        with self.client.connect() as conn:
+            with conn.begin():
+                for document, metadata, chunk_id, embedding in zip(
+                    texts, metadatas, ids, embeddings
+                ):
+                    chunks_table_data.append(
+                        {
+                            "id": chunk_id,
+                            "embedding": embedding,
+                            "document": document,
+                            "metadata": metadata,
+                        }
+                    )
+
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == 500:
+                        conn.execute(insert(chunks_table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+
+        return ids
 
     def delete_by_document_id(self, document_id: str):
         ids = self.get_ids_by_metadata_field('document_id', document_id)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def get_ids_by_metadata_field(self, key: str, value: str):
         result = None
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
             )
             result = session.execute(select_statement).fetchall()
         if result:
@@ -114,56 +161,140 @@
             ids = [item[0] for item in result]
         else:
             return None
 
+    def delete_by_uuids(self, ids: list[str] = None):
+        """Delete by vector IDs.
+
+        Args:
+            ids: List of ids to delete.
+        """
+        from pgvecto_rs.sqlalchemy import Vector
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(self.embedding_dimension)),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        try:
+            with self.client.connect() as conn:
+                with conn.begin():
+                    delete_condition = chunks_table.c.id.in_(ids)
+                    conn.execute(chunks_table.delete().where(delete_condition))
+                    return True
+        except Exception as e:
+            print("Delete operation failed:", str(e))  # noqa: T201
+            return False
+
     def delete_by_metadata_field(self, key: str, value: str):
         ids = self.get_ids_by_metadata_field(key, value)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def delete_by_ids(self, doc_ids: list[str]) -> None:
-        with Session(self._client._engine) as session:
+
+        with Session(self.client) as session:
+            ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids)
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
             )
             result = session.execute(select_statement).fetchall()
         if result:
             ids = [item[0] for item in result]
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def delete(self) -> None:
-        with Session(self._client._engine) as session:
-            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
+        with Session(self.client) as session:
+            session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
             session.commit()
 
     def text_exists(self, id: str) -> bool:
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
            )
             result = session.execute(select_statement).fetchall()
         return len(result) > 0
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        from pgvecto_rs.sdk import filters
-        filter_condition = filters.meta_contains(kwargs.get('filter'))
-        results = self._client.search(
-            top_k=int(kwargs.get('top_k')),
+        results = self.similarity_search_with_score_by_vector(
+            k=int(kwargs.get('top_k')),
             embedding=query_vector,
-            filter=filter_condition
+            filter=kwargs.get('filter')
         )
 
         # Organize results.
         docs = []
-        for record, dis in results:
-            metadata = record.meta
-            metadata['score'] = dis
+        for document, score in results:
             score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
-            if dis > score_threshold:
-                doc = Document(page_content=record.text,
-                               metadata=metadata)
-                docs.append(doc)
+            if score > score_threshold:
+                docs.append(document)
         return docs
 
+    def similarity_search_with_score_by_vector(
+            self,
+            embedding: list[float],
+            k: int = 4,
+            filter: Optional[dict] = None,
+    ) -> list[tuple[Document, float]]:
+        # Add the filter if provided
+        try:
+            from sqlalchemy.engine import Row
+        except ImportError:
+            raise ImportError(
+                "Could not import Row from sqlalchemy.engine. "
+                "Please 'pip install sqlalchemy>=1.4'."
+            )
+
+        filter_condition = ""
+        if filter is not None:
+            conditions = [
+                f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
+                else f"metadata->>{key!r} = {value[0]!r}"
+                for key, value in filter.items()
+            ]
+            filter_condition = f"WHERE {' AND '.join(conditions)}"
+
+        # Define the base query
+        sql_query = f"""
+            set vectors.enable_search_growing = on;
+            set vectors.enable_search_write = on;
+            SELECT document, metadata, embedding <-> :embedding as distance
+            FROM "{self._collection_name}"
+            {filter_condition}
+            ORDER BY embedding <-> :embedding
+            LIMIT :k
+        """
+
+        # Set up the query parameters
+        embedding_str = ", ".join(format(x) for x in embedding)
+        embedding_str = "[" + embedding_str + "]"
+        params = {"embedding": embedding_str, "k": k}
+
+        # Execute the query and fetch the results
+        with self.client.connect() as conn:
+            results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
+
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=result.document,
+                    metadata=result.metadata,
+                ),
+                result.distance,
+            )
+            for result in results
+        ]
+        return documents_with_scores
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz/relyt doesn't support bm25 search
         return []
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index b6ec7a11f..e7d3ca42b 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -126,7 +126,6 @@ class Vector:
                     "vector_store": {"class_prefix": collection_name}
                 }
                 self._dataset.index_struct = json.dumps(index_struct_dict)
-            dim = len(self._embeddings.embed_query("hello relyt"))
             return RelytVector(
                 collection_name=collection_name,
                 config=RelytConfig(
@@ -136,7 +135,7 @@ class Vector:
                     password=config.get('RELYT_PASSWORD'),
                    database=config.get('RELYT_DATABASE'),
                ),
-                dim=dim
+                group_id=self._dataset.id
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 1f2b37464..c32306e6d 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -86,7 +86,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -109,6 +109,12 @@ services:
       MILVUS_PASSWORD: Milvus
       # The milvus tls switch.
       MILVUS_SECURE: 'false'
+      # relyt configurations
+      RELYT_HOST: db
+      RELYT_PORT: 5432
+      RELYT_USER: postgres
+      RELYT_PASSWORD: difyai123456
+      RELYT_DATABASE: postgres
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -193,7 +199,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
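
End to end: vector_factory.py builds a RelytVector from the RELYT_* settings (now passing the dataset id as group_id instead of probing the embedding dimension up front), and docker-compose.yaml exposes those settings. A minimal usage sketch, assuming a reachable Relyt (PostgreSQL-compatible) instance at the placeholder credentials below, the pgvecto_rs package installed, and a configured Redis (create_collection caches table existence there); class and method names come from the patch, everything else is illustrative:

    from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
    from core.rag.models.document import Document

    # Placeholder connection details, mirroring the docker-compose defaults above.
    config = RelytConfig(
        host='db',
        port=5432,
        user='postgres',
        password='difyai123456',
        database='postgres',
    )

    # group_id is stamped into each chunk's metadata by add_texts().
    vector = RelytVector(collection_name='example_collection', config=config, group_id='dataset-id')

    docs = [Document(page_content='hello relyt', metadata={'doc_id': 'doc-1', 'document_id': 'doc-1'})]
    embeddings = [[0.1] * 1536]  # normally produced by the dataset's embedding model

    vector.create_collection(dimension=1536)  # DROP/CREATE the table plus the vectors index
    vector.add_texts(docs, embeddings)        # batched INSERTs, 500 rows at a time
    hits = vector.search_by_vector([0.1] * 1536, top_k=4)  # list[Document], ranked by L2 distance
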