mirror of https://gitee.com/dify_ai/dify.git, synced 2024-11-29 17:58:19 +08:00
parent 97fe817186
commit 6c4e6bf1d6
BIN  api/celerybeat-schedule.db  Normal file
Binary file not shown.
@@ -56,6 +56,7 @@ DEFAULTS = {
    'BILLING_ENABLED': 'False',
    'CAN_REPLACE_LOGO': 'False',
    'ETL_TYPE': 'dify',
    'KEYWORD_STORE': 'jieba',
    'BATCH_UPLOAD_LIMIT': 20
}

@@ -183,7 +184,7 @@ class Config:
        # Currently, only support: qdrant, milvus, zilliz, weaviate
        # ------------------------
        self.VECTOR_STORE = get_env('VECTOR_STORE')

        self.KEYWORD_STORE = get_env('KEYWORD_STORE')
        # qdrant settings
        self.QDRANT_URL = get_env('QDRANT_URL')
        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
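A minimal sketch of how the new `KEYWORD_STORE` default above is expected to be consumed. The `get_env` body here is an assumption for illustration (the real helper lives elsewhere in `api/config.py`); only the keys and default values come from the hunk:

```python
import os

# Trimmed copy of the defaults shown in the hunk above.
DEFAULTS = {
    'ETL_TYPE': 'dify',
    'KEYWORD_STORE': 'jieba',  # new default introduced by this commit
}

def get_env(key: str) -> str:
    """Illustrative only: read from the environment, fall back to DEFAULTS."""
    return os.environ.get(key, DEFAULTS.get(key))

# With no KEYWORD_STORE environment variable set, the keyword index backend
# resolves to 'jieba'.
print(get_env('KEYWORD_STORE'))
```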
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required

@@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource):
        if not data_source_binding:
            raise NotFound('Data source binding not found.')

        loader = NotionLoader(
            notion_access_token=data_source_binding.access_token,
        extractor = NotionExtractor(
            notion_workspace_id=workspace_id,
            notion_obj_id=page_id,
            notion_page_type=page_type
            notion_page_type=page_type,
            notion_access_token=data_source_binding.access_token
        )

        text_docs = loader.load()
        text_docs = extractor.extract()
        return {
            'content': "\n".join([doc.page_content for doc in text_docs])
        }, 200

@@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)
        notion_info_list = args['notion_info_list']
        extract_settings = []
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            for page in notion_info['pages']:
                extract_setting = ExtractSetting(
                    datasource_type="notion_import",
                    notion_info={
                        "notion_workspace_id": workspace_id,
                        "notion_obj_id": page['page_id'],
                        "notion_page_type": page['type']
                    },
                    document_model=args['doc_form']
                )
                extract_settings.append(extract_setting)
        indexing_runner = IndexingRunner()
        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                                     args['process_rule'], args['doc_form'],
                                                     args['doc_language'])
        return response, 200
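For context, a hedged sketch of the extraction path this file now uses: a `NotionExtractor` is built directly from the workspace/page identifiers plus the binding's access token, and `extract()` replaces the old `NotionLoader.load()`. Constructor arguments and the import path are taken from the hunk itself; the placeholder values are invented for illustration:

```python
from core.rag.extractor.notion_extractor import NotionExtractor  # import added in this diff

# Placeholder identifiers stand in for the request arguments used in the view above.
extractor = NotionExtractor(
    notion_workspace_id="example-workspace-id",
    notion_obj_id="example-page-id",
    notion_page_type="page",  # example value; the view passes page['type'] through unchanged
    notion_access_token="example-notion-token",
)

text_docs = extractor.extract()  # documents exposing .page_content, as used above
content = "\n".join(doc.page_content for doc in text_docs)
```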
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields

@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                            location='json', store_missing=False,
                            type=_validate_description_length)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True,
                            help='Invalid indexing technique.')
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True,
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
            'only_me', 'all_team_members'), help='Invalid permission.')
        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')

@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('indexing_technique', type=str, required=True,
        parser.add_argument('indexing_technique', type=str, required=True,
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')

@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)
        extract_settings = []
        if args['info_list']['data_source_type'] == 'upload_file':
            file_ids = args['info_list']['file_info_list']['file_ids']
            file_details = db.session.query(UploadFile).filter(

@@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource):
            if file_details is None:
                raise NotFound("File not found.")

            indexing_runner = IndexingRunner()

            try:
                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                                  args['process_rule'], args['doc_form'],
                                                                  args['doc_language'], args['dataset_id'],
                                                                  args['indexing_technique'])
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
            if file_details:
                for file_detail in file_details:
                    extract_setting = ExtractSetting(
                        datasource_type="upload_file",
                        upload_file=file_detail,
                        document_model=args['doc_form']
                    )
                    extract_settings.append(extract_setting)
        elif args['info_list']['data_source_type'] == 'notion_import':

            indexing_runner = IndexingRunner()

            try:
                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                    args['info_list']['notion_info_list'],
                                                                    args['process_rule'], args['doc_form'],
                                                                    args['doc_language'], args['dataset_id'],
                                                                    args['indexing_technique'])
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
            notion_info_list = args['info_list']['notion_info_list']
            for notion_info in notion_info_list:
                workspace_id = notion_info['workspace_id']
                for page in notion_info['pages']:
                    extract_setting = ExtractSetting(
                        datasource_type="notion_import",
                        notion_info={
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page['page_id'],
                            "notion_page_type": page['type']
                        },
                        document_model=args['doc_form']
                    )
                    extract_settings.append(extract_setting)
        else:
            raise ValueError('Data source type not support')
        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                                         args['process_rule'], args['doc_form'],
                                                         args['doc_language'], args['dataset_id'],
                                                         args['indexing_technique'])
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        return response, 200

@@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
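To summarise the refactor in the hunks above: the per-source `file_indexing_estimate` / `notion_indexing_estimate` calls are replaced by one `indexing_estimate` call fed with a list of `ExtractSetting` objects. A condensed sketch of the upload-file branch, assuming the call shapes shown in the diff; the helper name `estimate_upload_files` and its parameters are hypothetical packaging, not part of the change itself:

```python
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting

def estimate_upload_files(tenant_id, file_details, process_rule, doc_form,
                          doc_language, dataset_id, indexing_technique):
    # Build one ExtractSetting per uploaded file, as the new branch above does.
    extract_settings = [
        ExtractSetting(datasource_type="upload_file",
                       upload_file=file_detail,
                       document_model=doc_form)
        for file_detail in file_details
    ]
    # Single unified estimate call; the positional order mirrors the call site in the hunk.
    return IndexingRunner().indexing_estimate(
        tenant_id, extract_settings, process_rule, doc_form,
        doc_language, dataset_id, indexing_technique)
```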
@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import (

@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
        req_data = request.args

        document_id = req_data.get('document_id')

        # get default rules
        mode = DocumentService.DEFAULT_RULES['mode']
        rules = DocumentService.DEFAULT_RULES['rules']

@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
            if not file:
                raise NotFound('File not found.')

            extract_setting = ExtractSetting(
                datasource_type="upload_file",
                upload_file=file,
                document_model=document.doc_form
            )

            indexing_runner = IndexingRunner()

            try:
                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
                                                                  data_process_rule_dict, None,
                                                                  'English', dataset_id)
                response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
                                                             data_process_rule_dict, document.doc_form,
                                                             'English', dataset_id)
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "

@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
        data_process_rule = documents[0].dataset_process_rule
        data_process_rule_dict = data_process_rule.to_dict()
        info_list = []
        extract_settings = []
        for document in documents:
            if document.indexing_status in ['completed', 'error']:
                raise DocumentAlreadyFinishedError()

@@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                }
                info_list.append(notion_info)

        if dataset.data_source_type == 'upload_file':
            file_details = db.session.query(UploadFile).filter(
                UploadFile.tenant_id == current_user.current_tenant_id,
                UploadFile.id.in_(info_list)
            ).all()
            if document.data_source_type == 'upload_file':
                file_id = data_source_info['upload_file_id']
                file_detail = db.session.query(UploadFile).filter(
                    UploadFile.tenant_id == current_user.current_tenant_id,
                    UploadFile.id == file_id
                ).first()

            if file_details is None:
                raise NotFound("File not found.")
                if file_detail is None:
                    raise NotFound("File not found.")

                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=document.doc_form
                )
                extract_settings.append(extract_setting)

            elif document.data_source_type == 'notion_import':
                extract_setting = ExtractSetting(
                    datasource_type="notion_import",
                    notion_info={
                        "notion_workspace_id": data_source_info['notion_workspace_id'],
                        "notion_obj_id": data_source_info['notion_page_id'],
                        "notion_page_type": data_source_info['type']
                    },
                    document_model=document.doc_form
                )
                extract_settings.append(extract_setting)

            else:
                raise ValueError('Data source type not support')
        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                              data_process_rule_dict, None,
                                                              'English', dataset_id)
            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                                         data_process_rule_dict, document.doc_form,
                                                         'English', dataset_id)
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        elif dataset.data_source_type == 'notion_import':

            indexing_runner = IndexingRunner()
            try:
                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                    info_list,
                                                                    data_process_rule_dict,
                                                                    None, 'English', dataset_id)
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        else:
            raise ValueError('Data source type not support')
        return response
@@ -1,107 +0,0 @@
import tempfile
from pathlib import Path
from typing import Optional, Union

import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile

SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(url).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            with open(file_path, 'wb') as file:
                file.write(response.content)

            return cls.load_from_file(file_path, return_text)

    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[list[Document], str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        etl_type = current_app.config['ETL_TYPE']
        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
        if etl_type == 'Unstructured':
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
                    else MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_extension == '.msg':
                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
            elif file_extension == '.eml':
                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
            elif file_extension == '.ppt':
                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
            elif file_extension == '.pptx':
                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
            elif file_extension == '.xml':
                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
            else:
                # txt
                loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
                    else TextLoader(file_path, autodetect_encoding=True)
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
@@ -1,34 +0,0 @@
import logging

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
            self,
            file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> list[Document]:
        return [Document(page_content=self._load_as_text())]

    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
@@ -1,55 +0,0 @@
import logging
from typing import Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
            self,
            file_path: str,
            upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> list[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass
        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents
@@ -1,12 +1,12 @@
from collections.abc import Sequence
from typing import Any, Optional, cast

from langchain.schema import Document
from sqlalchemy import func

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@@ -1,13 +1,8 @@
import logging
from typing import Optional

from flask import current_app

from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation

@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
        embedding_provider_name = collection_binding_detail.provider_name
        embedding_model_name = collection_binding_detail.model_name

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=app_record.tenant_id,
            provider=embedding_provider_name,
            model_type=ModelType.TEXT_EMBEDDING,
            model=embedding_model_name
        )

        # get embedding model
        embeddings = CacheEmbedding(model_instance)

        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name,
            embedding_model_name,

@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
            collection_binding_id=dataset_collection_binding.id
        )

        vector_index = VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings,
            attributes=['doc_id', 'annotation_id', 'app_id']
        )
        vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

        documents = vector_index.search(
        documents = vector.search_by_vector(
            query=query,
            search_type='similarity_score_threshold',
            search_kwargs={
                'k': 1,
                'score_threshold': score_threshold,
                'filter': {
                    'group_id': [dataset.id]
                }
            k=1,
            score_threshold=score_threshold,
            filter={
                'group_id': [dataset.id]
            }
        )
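The hunk above swaps the old `VectorIndex.search(..., search_kwargs={...})` call for the new `Vector.search_by_vector(...)`, which takes `k`, `score_threshold` and `filter` as plain keyword arguments. A minimal sketch of the new call, using only the names and arguments visible in the diff (the `dataset`, `query` and `score_threshold` variables are assumed to be in scope as they are in the feature code):

```python
from core.rag.datasource.vdb.vector_factory import Vector  # import added in this diff

# Build the vector client for the annotation dataset, keeping the same stored attributes.
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

documents = vector.search_by_vector(
    query=query,                      # the user message being matched against annotations
    k=1,                              # only the single best-matching annotation is needed
    score_threshold=score_threshold,  # per-app similarity cut-off
    filter={'group_id': [dataset.id]},
)
```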
@@ -1,51 +0,0 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                model_type=ModelType.TEXT_EMBEDDING,
                provider=dataset.embedding_model_provider,
                model=dataset.embedding_model
            )

            embeddings = CacheEmbedding(embedding_model)

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')

    @classmethod
    def get_default_high_quality_index(cls, dataset: Dataset):
        embeddings = OpenAIEmbeddings(openai_api_key=' ')
        return VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings
        )
@ -1,305 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores import VectorStore
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
class BaseVectorIndex(BaseIndex):
|
||||
|
||||
def __init__(self, dataset: Dataset, embeddings: Embeddings):
|
||||
super().__init__(dataset)
|
||||
self._embeddings = embeddings
|
||||
self._vector_store = None
|
||||
|
||||
def get_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_index_struct(self) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_vector_store_class(self) -> type:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_full_text_index(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
|
||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||
|
||||
if search_type == 'similarity_score_threshold':
|
||||
score_threshold = search_kwargs.get("score_threshold")
|
||||
if (score_threshold is None) or (not isinstance(score_threshold, float)):
|
||||
search_kwargs['score_threshold'] = .0
|
||||
|
||||
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
|
||||
query, **search_kwargs
|
||||
)
|
||||
|
||||
docs = []
|
||||
for doc, similarity in docs_with_similarity:
|
||||
doc.metadata['score'] = similarity
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
# similarity k
|
||||
# mmr k, fetch_k, lambda_mult
|
||||
# similarity_score_threshold k
|
||||
return vector_store.as_retriever(
|
||||
search_type=search_type,
|
||||
search_kwargs=search_kwargs
|
||||
).get_relevant_documents(query)
|
||||
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return vector_store.as_retriever(**kwargs)
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
if kwargs.get('duplicate_check', False):
|
||||
texts = self._filter_duplicate_texts(texts)
|
||||
|
||||
uuids = self._get_uuids(texts)
|
||||
vector_store.add_documents(texts, uuids=uuids)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return vector_store.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
for node_id in ids:
|
||||
vector_store.del_text(node_id)
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
if self.dataset.collection_binding_id:
|
||||
vector_store.delete_by_group_id(group_id)
|
||||
else:
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
return False
|
||||
|
||||
def recreate_dataset(self, dataset: Dataset):
|
||||
logging.info(f"Recreating dataset {dataset.id}")
|
||||
|
||||
try:
|
||||
self.delete()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
origin_index_struct = self.dataset.index_struct[:]
|
||||
self.dataset.index_struct = None
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create(documents)
|
||||
except Exception as e:
|
||||
self.dataset.index_struct = origin_index_struct
|
||||
raise e
|
||||
|
||||
dataset.index_struct = json.dumps(self.to_index_struct())
|
||||
|
||||
db.session.commit()
|
||||
|
||||
self.dataset = dataset
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def create_qdrant_dataset(self, dataset: Dataset):
|
||||
logging.info(f"create_qdrant_dataset {dataset.id}")
|
||||
|
||||
try:
|
||||
self.delete()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def update_qdrant_dataset(self, dataset: Dataset):
|
||||
logging.info(f"update_qdrant_dataset {dataset.id}")
|
||||
|
||||
segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).first()
|
||||
|
||||
if segment:
|
||||
try:
|
||||
exist = self.text_exists(segment.index_node_id)
|
||||
if exist:
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.add_texts(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"delete original collection: {dataset.id}")
|
||||
|
||||
self.delete()
|
||||
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
@ -1,165 +0,0 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
secure: bool = False
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
if not values['port']:
|
||||
raise ValueError("config MILVUS_PORT is required")
|
||||
if not values['user']:
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values['password']:
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'secure': self.secure
|
||||
}
|
||||
|
||||
|
||||
class MilvusVectorIndex(BaseVectorIndex):
|
||||
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client_config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'milvus'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
connection_args=self._client_config.to_milvus_params(),
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=collection_name,
|
||||
ids=uuids,
|
||||
content_payload_key='page_content'
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
return MilvusVectorStore(
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embedding_function=self._embeddings,
|
||||
connection_args=self._client_config.to_milvus_params()
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return MilvusVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_document_id(document_id)
|
||||
if ids:
|
||||
vector_store.del_texts({
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
vector_store.del_texts({
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_doc_ids(doc_ids)
|
||||
vector_store.del_texts({
|
||||
'filter': f' id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
@ -1,229 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import qdrant_client
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http.models import HnswConfigDiff
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.qdrant_vector_store import QdrantVectorStore
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
timeout: float = 20
|
||||
root_path: Optional[str]
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {
|
||||
'path': path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
'timeout': self.timeout
|
||||
}
|
||||
|
||||
|
||||
class QdrantVectorIndex(BaseVectorIndex):
|
||||
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client_config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'qdrant'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
return dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = QdrantVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
ids=uuids,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id',
|
||||
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False),
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = QdrantVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=collection_name,
|
||||
ids=uuids,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id',
|
||||
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False),
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
client = qdrant_client.QdrantClient(
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return QdrantVectorStore(
|
||||
client=client,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embeddings=self._embeddings,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id'
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return QdrantVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchValue(value=document_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
for node_id in ids:
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=group_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
return vector_store.similarity_search_by_bm25(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
],
|
||||
), kwargs.get('top_k', 2))
|
@ -1,90 +0,0 @@
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document
|
||||
|
||||
|
||||
class VectorIndex:
|
||||
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
|
||||
attributes: list = None):
|
||||
if attributes is None:
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self._dataset = dataset
|
||||
self._embeddings = embeddings
|
||||
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
|
||||
attributes: list) -> BaseVectorIndex:
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError("Vector store must be specified.")
|
||||
|
||||
if vector_type == "weaviate":
|
||||
from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
|
||||
|
||||
return WeaviateVectorIndex(
|
||||
dataset=dataset,
|
||||
config=WeaviateConfig(
|
||||
endpoint=config.get('WEAVIATE_ENDPOINT'),
|
||||
api_key=config.get('WEAVIATE_API_KEY'),
|
||||
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
||||
),
|
||||
embeddings=embeddings,
|
||||
attributes=attributes
|
||||
)
|
||||
elif vector_type == "qdrant":
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
|
||||
|
||||
return QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=config.get('QDRANT_URL'),
|
||||
api_key=config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path,
|
||||
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
elif vector_type == "milvus":
|
||||
from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
|
||||
|
||||
return MilvusVectorIndex(
|
||||
dataset=dataset,
|
||||
config=MilvusConfig(
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
if not self._dataset.index_struct_dict:
|
||||
self._vector_index.create(texts, **kwargs)
|
||||
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
|
||||
db.session.commit()
|
||||
return
|
||||
|
||||
self._vector_index.add_texts(texts, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._vector_index is not None:
|
||||
method = getattr(self._vector_index, name)
|
||||
if callable(method):
|
||||
return method
|
||||
|
||||
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
|
||||
|
@ -1,179 +0,0 @@
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
||||
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
|
||||
|
||||
weaviate.connect.connection.has_grpc = False
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint,
|
||||
auth_client_secret=auth_config,
|
||||
timeout_config=(5, 60),
|
||||
startup_period=None
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
|
||||
client.batch.configure(
|
||||
# `batch_size` takes an `int` value to enable auto-batching
|
||||
# (`None` is used for manual batching)
|
||||
batch_size=config.batch_size,
|
||||
# dynamically update the `batch_size` based on import speed
|
||||
dynamic=True,
|
||||
# `timeout_retries` takes an `int` value to retry on time outs
|
||||
timeout_retries=3,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'weaviate'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
attributes = self._attributes
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
|
||||
return WeaviateVectorStore(
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
text_key='text',
|
||||
embedding=self._embeddings,
|
||||
attributes=attributes,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return WeaviateVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": ["document_id"],
|
||||
"valueText": document_id
|
||||
})
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
})
|
||||
|
||||
def delete_by_group_id(self, group_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
|
||||
|
@ -9,20 +9,20 @@ from typing import Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import Document
|
||||
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -31,7 +31,6 @@ from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceBinding
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@ -57,38 +56,19 @@ class IndexingRunner:
|
||||
processing_rule = db.session.query(DatasetProcessRule). \
|
||||
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
|
||||
first()
|
||||
index_type = dataset_document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
# extract
|
||||
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
|
||||
|
||||
# load file
|
||||
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
|
||||
# transform
|
||||
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
|
||||
# save segment
|
||||
self._load_segments(dataset, dataset_document, documents)
|
||||
|
||||
# get embedding model instance
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
|
||||
# get splitter
|
||||
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||
|
||||
# split to documents
|
||||
documents = self._step_split(
|
||||
text_docs=text_docs,
|
||||
splitter=splitter,
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
self._build_index(
|
||||
# load
|
||||
self._load(
|
||||
index_processor=index_processor,
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
documents=documents
|
||||
@ -134,39 +114,19 @@ class IndexingRunner:
|
||||
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
|
||||
first()
|
||||
|
||||
# load file
|
||||
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
|
||||
index_type = dataset_document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
# extract
|
||||
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
|
||||
|
||||
# get embedding model instance
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
# transform
|
||||
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
|
||||
# save segment
|
||||
self._load_segments(dataset, dataset_document, documents)
|
||||
|
||||
# get splitter
|
||||
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||
|
||||
# split to documents
|
||||
documents = self._step_split(
|
||||
text_docs=text_docs,
|
||||
splitter=splitter,
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
|
||||
# build index
|
||||
self._build_index(
|
||||
# load
|
||||
self._load(
|
||||
index_processor=index_processor,
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
documents=documents
|
||||
@ -220,7 +180,15 @@ class IndexingRunner:
|
||||
documents.append(document)
|
||||
|
||||
# build index
|
||||
self._build_index(
|
||||
# get the process rule
|
||||
processing_rule = db.session.query(DatasetProcessRule). \
|
||||
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
|
||||
first()
|
||||
|
||||
index_type = dataset_document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
|
||||
self._load(
|
||||
index_processor=index_processor,
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
documents=documents
|
||||
@ -239,16 +207,16 @@
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                          indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(file_details)
            count = len(extract_settings)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -284,16 +252,18 @@ class IndexingRunner:
|
||||
total_segments = 0
|
||||
total_price = 0
|
||||
currency = 'USD'
|
||||
for file_detail in file_details:
|
||||
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
all_text_docs = []
|
||||
for extract_setting in extract_settings:
|
||||
# extract
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||
all_text_docs.extend(text_docs)
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"],
|
||||
rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
|
||||
# load data from file
|
||||
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
|
||||
|
||||
# get splitter
|
||||
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||
|
||||
@ -305,7 +275,6 @@ class IndexingRunner:
|
||||
)
|
||||
|
||||
total_segments += len(documents)
|
||||
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
@ -364,154 +333,8 @@ class IndexingRunner:
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
# check document limit
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
if features.billing.enabled:
|
||||
count = len(notion_info_list)
|
||||
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
|
||||
embedding_model_instance = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
else:
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
total_price = 0
|
||||
currency = 'USD'
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
|
||||
for page in notion_info['pages']:
|
||||
loader = NotionLoader(
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page['page_id'],
|
||||
notion_page_type=page['type']
|
||||
)
|
||||
documents = loader.load()
|
||||
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"],
|
||||
rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
|
||||
# get splitter
|
||||
splitter = self._get_splitter(processing_rule, embedding_model_instance)
|
||||
|
||||
# split to documents
|
||||
documents = self._split_to_documents_for_estimate(
|
||||
text_docs=documents,
|
||||
splitter=splitter,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
total_segments += len(documents)
|
||||
|
||||
embedding_model_type_instance = None
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
if indexing_technique == 'high_quality' and embedding_model_type_instance:
|
||||
tokens += embedding_model_type_instance.get_num_tokens(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
texts=[document.page_content]
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
|
||||
price_info = model_type_instance.get_price(
|
||||
model=model_instance.model,
|
||||
credentials=model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=total_segments * 2000,
|
||||
)
|
||||
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(price_info.total_amount),
|
||||
"currency": price_info.currency,
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
embedding_price_info = embedding_model_type_instance.get_price(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
total_price = '{:f}'.format(embedding_price_info.total_amount)
|
||||
currency = embedding_price_info.currency
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": total_price,
|
||||
"currency": currency,
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
|
||||
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
|
||||
-> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
|
||||
return []
|
||||
@ -527,11 +350,27 @@ class IndexingRunner:
|
||||
one_or_none()
|
||||
|
||||
if file_detail:
|
||||
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file_detail,
|
||||
document_model=dataset_document.doc_form
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
|
||||
elif dataset_document.data_source_type == 'notion_import':
|
||||
loader = NotionLoader.from_document(dataset_document)
|
||||
text_docs = loader.load()
|
||||
|
||||
if (not data_source_info or 'notion_workspace_id' not in data_source_info
|
||||
or 'notion_page_id' not in data_source_info):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": data_source_info['notion_workspace_id'],
|
||||
"notion_obj_id": data_source_info['notion_page_id'],
|
||||
"notion_page_type": data_source_info['notion_page_type'],
|
||||
"document": dataset_document
|
||||
},
|
||||
document_model=dataset_document.doc_form
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
@ -545,8 +384,6 @@ class IndexingRunner:
|
||||
# replace doc id to document model id
|
||||
text_docs = cast(list[Document], text_docs)
|
||||
for text_doc in text_docs:
|
||||
# remove invalid symbol
|
||||
text_doc.page_content = self.filter_string(text_doc.page_content)
|
||||
text_doc.metadata['document_id'] = dataset_document.id
|
||||
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
|
||||
|
||||
@ -787,12 +624,12 @@ class IndexingRunner:
|
||||
for q, a in matches if q and a
|
||||
]
|
||||
|
||||
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
|
||||
def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
|
||||
dataset_document: DatasetDocument, documents: list[Document]) -> None:
|
||||
"""
|
||||
Build the index for the document.
|
||||
insert index and update document/segment status to completed
|
||||
"""
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
@ -825,13 +662,8 @@ class IndexingRunner:
|
||||
)
|
||||
for document in chunk_documents
|
||||
)
|
||||
|
||||
# save vector index
|
||||
if vector_index:
|
||||
vector_index.add_texts(chunk_documents)
|
||||
|
||||
# save keyword index
|
||||
keyword_table_index.add_texts(chunk_documents)
|
||||
# load index
|
||||
index_processor.load(dataset, chunk_documents)
|
||||
|
||||
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
|
||||
db.session.query(DocumentSegment).filter(
|
||||
@ -911,14 +743,64 @@ class IndexingRunner:
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts(documents, duplicate_check=True)
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_processor.load(dataset, documents)
|
||||
|
||||
# save keyword index
|
||||
index = IndexBuilder.get_index(dataset, 'economy')
|
||||
if index:
|
||||
index.add_texts(documents)
|
||||
def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
|
||||
text_docs: list[Document], process_rule: dict) -> list[Document]:
|
||||
# get embedding model instance
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
|
||||
documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
|
||||
process_rule=process_rule)
|
||||
|
||||
return documents
|
||||
|
||||
def _load_segments(self, dataset, dataset_document, documents):
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(
|
||||
dataset=dataset,
|
||||
user_id=dataset_document.created_by,
|
||||
document_id=dataset_document.id
|
||||
)
|
||||
|
||||
# add document segments
|
||||
doc_store.add_documents(documents)
|
||||
|
||||
# update document status to indexing
|
||||
cur_time = datetime.datetime.utcnow()
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="indexing",
|
||||
extra_update_params={
|
||||
DatasetDocument.cleaning_completed_at: cur_time,
|
||||
DatasetDocument.splitting_completed_at: cur_time,
|
||||
}
|
||||
)
|
||||
|
||||
# update segment status to indexing
|
||||
self._update_segments_by_document(
|
||||
dataset_document_id=dataset_document.id,
|
||||
update_params={
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.indexing_at: datetime.datetime.utcnow()
|
||||
}
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
|
0    api/core/rag/__init__.py    Normal file
38   api/core/rag/cleaner/clean_processor.py    Normal file
@ -0,0 +1,38 @@
import re


class CleanProcessor:

    @classmethod
    def clean(cls, text: str, process_rule: dict) -> str:
        # default clean
        # remove invalid symbol
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub('\uFFFE', '', text)

        rules = process_rule['rules'] if process_rule else None
        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)
        return text

    def filter_string(self, text):

        return text
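For orientation, a minimal usage sketch of CleanProcessor.clean follows; the shape of the process_rule dict is an assumption inferred from the keys the method reads above, not a structure defined elsewhere in this patch.

```python
# Hypothetical driver for CleanProcessor; rule ids mirror the branches checked above.
from core.rag.cleaner.clean_processor import CleanProcessor

process_rule = {
    'rules': {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': True},
        ]
    }
}

raw = "Contact   us at someone@example.com\n\n\n\nDocs: https://example.com/page"
print(CleanProcessor.clean(raw, process_rule))
# roughly: "Contact us at \n\nDocs: " (extra whitespace collapsed, email and URL stripped)
```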
12   api/core/rag/cleaner/cleaner_base.py    Normal file
@ -0,0 +1,12 @@
"""Abstract interface for document cleaner implementations."""
from abc import ABC, abstractmethod


class BaseCleaner(ABC):
    """Interface for clean chunk content.
    """

    @abstractmethod
    def clean(self, content: str):
        raise NotImplementedError
@ -0,0 +1,12 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner


class UnstructuredNonAsciiCharsCleaner(BaseCleaner):

    def clean(self, content) -> str:
        """clean document content."""
        from unstructured.cleaners.core import clean_extra_whitespace

        # Returns "ITEM 1A: RISK FACTORS"
        return clean_extra_whitespace(content)
@ -0,0 +1,15 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner


class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):

    def clean(self, content) -> str:
        """clean document content."""
        import re

        from unstructured.cleaners.core import group_broken_paragraphs

        para_split_re = re.compile(r"(\s*\n\s*){3}")

        return group_broken_paragraphs(content, paragraph_split=para_split_re)
@ -0,0 +1,12 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner


class UnstructuredNonAsciiCharsCleaner(BaseCleaner):

    def clean(self, content) -> str:
        """clean document content."""
        from unstructured.cleaners.core import clean_non_ascii_chars

        # Returns "This text containsnon-ascii characters!"
        return clean_non_ascii_chars(content)
@ -0,0 +1,11 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner


class UnstructuredNonAsciiCharsCleaner(BaseCleaner):

    def clean(self, content) -> str:
        """Replaces unicode quote characters, such as the \x91 character in a string."""

        from unstructured.cleaners.core import replace_unicode_quotes
        return replace_unicode_quotes(content)
@ -0,0 +1,11 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner


class UnstructuredTranslateTextCleaner(BaseCleaner):

    def clean(self, content) -> str:
        """clean document content."""
        from unstructured.cleaners.translate import translate_text

        return translate_text(content)
0    api/core/rag/data_post_processor/__init__.py    Normal file
49   api/core/rag/data_post_processor/data_post_processor.py    Normal file
@ -0,0 +1,49 @@
from typing import Optional

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rerank.rerank import RerankRunner


class DataPostProcessor:
    """Interface for data post-processing document.
    """

    def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
        self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
        self.reorder_runner = self._get_reorder_runner(reorder_enabled)

    def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
               top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
        if self.rerank_runner:
            documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)

        if self.reorder_runner:
            documents = self.reorder_runner.run(documents)

        return documents

    def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
        if reranking_model:
            try:
                model_manager = ModelManager()
                rerank_model_instance = model_manager.get_model_instance(
                    tenant_id=tenant_id,
                    provider=reranking_model['reranking_provider_name'],
                    model_type=ModelType.RERANK,
                    model=reranking_model['reranking_model_name']
                )
            except InvokeAuthorizationError:
                return None
            return RerankRunner(rerank_model_instance)
        return None

    def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
        if reorder_enabled:
            return ReorderRunner()
        return None
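A brief usage sketch of DataPostProcessor follows; the tenant id and rerank model names are placeholders, and a configured rerank provider plus an application/database context are assumed rather than shown here.

```python
# Illustrative only: identifiers below are placeholders, not values from this patch.
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.models.document import Document

candidate_docs = [
    Document(page_content="Refunds are issued within 14 days.", metadata={'doc_id': 'seg-1'}),
    Document(page_content="Shipping takes 3-5 business days.", metadata={'doc_id': 'seg-2'}),
]

post_processor = DataPostProcessor(
    tenant_id='tenant-123',                        # placeholder tenant
    reranking_model={
        'reranking_provider_name': 'cohere',       # placeholder provider
        'reranking_model_name': 'rerank-english-v2.0',
    },
    reorder_enabled=False,
)

top_docs = post_processor.invoke(
    query="What is the refund policy?",
    documents=candidate_docs,
    score_threshold=0.5,
    top_n=1,
)
```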
19   api/core/rag/data_post_processor/reorder.py    Normal file
@ -0,0 +1,19 @@

from langchain.schema import Document


class ReorderRunner:

    def run(self, documents: list[Document]) -> list[Document]:
        # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list
        odd_elements = documents[::2]

        # Retrieve elements from even indices (1, 3, 5, etc.) of the documents list
        even_elements = documents[1::2]

        # Reverse the list of elements from even indices
        even_elements_reversed = even_elements[::-1]

        new_documents = odd_elements + even_elements_reversed

        return new_documents
0    api/core/rag/datasource/__init__.py    Normal file
21   api/core/rag/datasource/entity/embedding.py    Normal file
@ -0,0 +1,21 @@
from abc import ABC, abstractmethod


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronous Embed search docs."""
        raise NotImplementedError

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronous Embed query text."""
        raise NotImplementedError
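Since Embeddings is only an interface, a toy subclass may help illustrate the contract; the hash-based vectors below are purely a stand-in for demonstration and are not how the project's real embedding wrapper works.

```python
# Minimal illustrative subclass of the Embeddings interface above.
import hashlib

from core.rag.datasource.entity.embedding import Embeddings


class FakeEmbeddings(Embeddings):
    def __init__(self, size: int = 8):
        self._size = size

    def _embed(self, text: str) -> list[float]:
        # Deterministic pseudo-embedding derived from a SHA-256 digest.
        digest = hashlib.sha256(text.encode('utf-8')).digest()
        return [b / 255 for b in digest[:self._size]]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self._embed(t) for t in texts]

    def embed_query(self, text: str) -> list[float]:
        return self._embed(text)
```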
0    api/core/rag/datasource/keyword/__init__.py    Normal file
0    api/core/rag/datasource/keyword/jieba/__init__.py    Normal file
@ -2,11 +2,11 @@ import json
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
|
||||
|
||||
@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel):
|
||||
max_keywords_per_chunk: int = 10
|
||||
|
||||
|
||||
class KeywordTableIndex(BaseIndex):
|
||||
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
|
||||
class Jieba(BaseKeyword):
|
||||
def __init__(self, dataset: Dataset):
|
||||
super().__init__(dataset)
|
||||
self._config = config
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = {}
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self.dataset.id,
|
||||
keyword_table=json.dumps({
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": {}
|
||||
}
|
||||
}, cls=SetEncoder)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = {}
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self.dataset.id,
|
||||
keyword_table=json.dumps({
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": {}
|
||||
}
|
||||
}, cls=SetEncoder)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
keywords_list = kwargs.get('keywords_list', None)
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
keywords = keywords_list[i]
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
pass
|
||||
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
return KeywordTableRetriever(index=self, **kwargs)
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||
k = search_kwargs.get('k') if search_kwargs.get('k') else 4
|
||||
k = kwargs.get('top_k', 4)
|
||||
|
||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
|
||||
|
||||
@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex):
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
keyword_table_dict = {
|
||||
'__type__': 'keyword_table',
|
||||
@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
).first()
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.add(document_segment)
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: list[str]):
|
||||
@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex):
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
index: KeywordTableIndex
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> list[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
return self.index.search(query, **self.search_kwargs)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> list[Document]:
|
||||
raise NotImplementedError("KeywordTableRetriever does not support async")
|
||||
|
||||
|
||||
class SetEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, set):
|
@ -3,7 +3,7 @@ import re
import jieba
from jieba.analyse import default_tfidf

from core.index.keyword_table_index.stopwords import STOPWORDS
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
@ -3,22 +3,17 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class BaseIndex(ABC):
|
||||
class BaseKeyword(ABC):
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
@abstractmethod
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -34,31 +29,18 @@ class BaseIndex(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
def delete_by_document_id(self, document_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts:
|
||||
doc_id = text.metadata['doc_id']
|
60   api/core/rag/datasource/keyword/keyword_factory.py    Normal file
@ -0,0 +1,60 @@
from typing import Any, cast

from flask import current_app

from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from models.dataset import Dataset


class Keyword:
    def __init__(self, dataset: Dataset):
        self._dataset = dataset
        self._keyword_processor = self._init_keyword()

    def _init_keyword(self) -> BaseKeyword:
        config = cast(dict, current_app.config)
        keyword_type = config.get('KEYWORD_STORE')

        if not keyword_type:
            raise ValueError("Keyword store must be specified.")

        if keyword_type == "jieba":
            return Jieba(
                dataset=self._dataset
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def create(self, texts: list[Document], **kwargs):
        self._keyword_processor.create(texts, **kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        self._keyword_processor.add_texts(texts, **kwargs)

    def text_exists(self, id: str) -> bool:
        return self._keyword_processor.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        self._keyword_processor.delete_by_ids(ids)

    def delete_by_document_id(self, document_id: str) -> None:
        self._keyword_processor.delete_by_document_id(document_id)

    def delete(self) -> None:
        self._keyword_processor.delete()

    def search(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        return self._keyword_processor.search(query, **kwargs)

    def __getattr__(self, name):
        if self._keyword_processor is not None:
            method = getattr(self._keyword_processor, name)
            if callable(method):
                return method

        raise AttributeError(f"'Keyword' object has no attribute '{name}'")
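A short usage sketch of the Keyword facade, assuming a Flask app context with KEYWORD_STORE set to 'jieba' and an existing Dataset row; neither of those prerequisites is set up in this snippet.

```python
# Hypothetical caller of the Keyword facade; `dataset` is assumed to be an
# existing models.dataset.Dataset instance loaded elsewhere.
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document

keyword = Keyword(dataset=dataset)

keyword.add_texts([
    Document(
        page_content="Dify supports keyword and vector retrieval.",
        metadata={'doc_id': 'segment-1', 'document_id': 'doc-1', 'dataset_id': dataset.id},
    ),
])

# Jieba reads 'top_k' from kwargs (default 4), as shown in the handler above.
results = keyword.search("keyword retrieval", top_k=4)
```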
165  api/core/rag/datasource/retrieval_service.py    Normal file
@ -0,0 +1,165 @@
import threading
from typing import Optional

from flask import Flask, current_app
from flask_login import current_user

from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class RetrievalService:

    @classmethod
    def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
        all_documents = []
        threads = []
        # retrieval_model source with keyword
        if retrival_method == 'keyword_search':
            keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'top_k': top_k
            })
            threads.append(keyword_thread)
            keyword_thread.start()
        # retrieval_model source with semantic
        if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'top_k': top_k,
                'score_threshold': score_threshold,
                'reranking_model': reranking_model,
                'all_documents': all_documents,
                'retrival_method': retrival_method
            })
            threads.append(embedding_thread)
            embedding_thread.start()

        # retrieval source with full text
        if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'retrival_method': retrival_method,
                'score_threshold': score_threshold,
                'top_k': top_k,
                'reranking_model': reranking_model,
                'all_documents': all_documents
            })
            threads.append(full_text_index_thread)
            full_text_index_thread.start()

        for thread in threads:
            thread.join()

        if retrival_method == 'hybrid_search':
            data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
            all_documents = data_post_processor.invoke(
                query=query,
                documents=all_documents,
                score_threshold=score_threshold,
                top_n=top_k
            )
        return all_documents

    @classmethod
    def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
                       top_k: int, all_documents: list):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

            keyword = Keyword(
                dataset=dataset
            )

            documents = keyword.search(
                query,
                k=top_k
            )
            all_documents.extend(documents)

    @classmethod
    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                         all_documents: list, retrival_method: str):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

            vector = Vector(
                dataset=dataset
            )

            documents = vector.search_by_vector(
                query,
                search_type='similarity_score_threshold',
                k=top_k,
                score_threshold=score_threshold,
                filter={
                    'group_id': [dataset.id]
                }
            )

            if documents:
                if reranking_model and retrival_method == 'semantic_search':
                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
                    all_documents.extend(data_post_processor.invoke(
                        query=query,
                        documents=documents,
                        score_threshold=score_threshold,
                        top_n=len(documents)
                    ))
                else:
                    all_documents.extend(documents)

    @classmethod
    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                               all_documents: list, retrival_method: str):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

            vector_processor = Vector(
                dataset=dataset,
            )

            documents = vector_processor.search_by_full_text(
                query,
                top_k=top_k
            )
            if documents:
                if reranking_model and retrival_method == 'full_text_search':
                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
                    all_documents.extend(data_post_processor.invoke(
                        query=query,
                        documents=documents,
                        score_threshold=score_threshold,
                        top_n=len(documents)
                    ))
                else:
                    all_documents.extend(documents)
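Finally, a sketch of how a caller might invoke RetrievalService.retrieve; it presumes a Flask request context, an already indexed dataset, and (for reranking) a configured rerank model, with placeholder identifiers throughout. Note that the parameter name retrival_method is spelled as in the source.

```python
# Hypothetical call site; 'dataset-uuid' and the rerank model names are placeholders.
from core.rag.datasource.retrieval_service import RetrievalService

documents = RetrievalService.retrieve(
    retrival_method='hybrid_search',   # or 'keyword_search', 'semantic_search', 'full_text_search'
    dataset_id='dataset-uuid',
    query="How do I rotate an API key?",
    top_k=4,
    score_threshold=0.0,
    reranking_model={
        'reranking_provider_name': 'cohere',        # placeholder provider
        'reranking_model_name': 'rerank-english-v2.0',
    },
)
```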
0    api/core/rag/datasource/vdb/__init__.py    Normal file
10   api/core/rag/datasource/vdb/field.py    Normal file
@ -0,0 +1,10 @@
from enum import Enum


class Field(Enum):
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"
    GROUP_KEY = "group_id"
    VECTOR = "vector"
    TEXT_KEY = "text"
    PRIMARY_KEY = " id"
0    api/core/rag/datasource/vdb/milvus/__init__.py    Normal file
214  api/core/rag/datasource/vdb/milvus/milvus_vector.py    Normal file
@ -0,0 +1,214 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
secure: bool = False
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
if not values['port']:
|
||||
raise ValueError("config MILVUS_PORT is required")
|
||||
if not values['user']:
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values['password']:
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'secure': self.secure
|
||||
}
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = 'Session'
|
||||
self._fields = []
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'milvus'
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
metadatas = [d.metadata for d in texts]
|
||||
|
||||
# Grab the existing collection if it exists
|
||||
from pymilvus import utility
|
||||
alias = uuid4().hex
|
||||
if self._client_config.secure:
|
||||
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
else:
|
||||
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
|
||||
if not utility.has_collection(self._collection_name, using=alias):
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
insert_dict = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata
|
||||
}
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
total_count = len(insert_dict_list)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
batch_insert_list = insert_dict_list[i:i + 1000]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||
)
|
||||
raise e
|
||||
return pks
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
|
||||
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._client.delete(collection_name=self._collection_name, pks=ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        result = self._client.query(collection_name=self._collection_name,
                                    filter=f'metadata["{key}"] == "{value}"',
                                    output_fields=["id"])
        if result:
            return [item["id"] for item in result]
        else:
            return None

    def delete_by_metadata_field(self, key: str, value: str):
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._client.delete(collection_name=self._collection_name, pks=ids)

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        self._client.delete(collection_name=self._collection_name, pks=doc_ids)

    def delete(self) -> None:
        from pymilvus import utility
        utility.drop_collection(self._collection_name, None)

    def text_exists(self, id: str) -> bool:
        result = self._client.query(collection_name=self._collection_name,
                                    filter=f'metadata["doc_id"] == "{id}"',
                                    output_fields=["id"])

        return len(result) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        # Set search parameters.
        results = self._client.search(collection_name=self._collection_name,
                                      data=[query_vector],
                                      limit=kwargs.get('top_k', 4),
                                      output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
                                      )
        # Organize results.
        docs = []
        for result in results[0]:
            metadata = result['entity'].get(Field.METADATA_KEY.value)
            metadata['score'] = result['distance']
            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
            if result['distance'] > score_threshold:
                doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
                               metadata=metadata)
                docs.append(doc)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # milvus/zilliz doesn't support bm25 search
        return []

    def create_collection(
            self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
    ) -> str:
        from pymilvus import CollectionSchema, DataType, FieldSchema
        from pymilvus.orm.types import infer_dtype_bydata

        # Determine embedding dim
        dim = len(embeddings[0])
        fields = []
        if metadatas:
            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))

        # Create the text field
        fields.append(
            FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
        )
        # Create the primary key field
        fields.append(
            FieldSchema(
                Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
            )
        )
        # Create the vector field, supports binary or float vectors
        fields.append(
            FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
        )

        # Create the schema for the collection
        schema = CollectionSchema(fields)

        for x in schema.fields:
            self._fields.append(x.name)
        # Since primary field is auto-id, no need to track it
        self._fields.remove(Field.PRIMARY_KEY.value)

        # Create the collection
        collection_name = self._collection_name
        self._client.create_collection_with_schema(collection_name=collection_name,
                                                   schema=schema, index_param=index_params,
                                                   consistency_level=self._consistency_level)
        return collection_name

    def _init_client(self, config) -> MilvusClient:
        if config.secure:
            uri = "https://" + str(config.host) + ":" + str(config.port)
        else:
            uri = "http://" + str(config.host) + ":" + str(config.port)
        client = MilvusClient(uri=uri, user=config.user, password=config.password)
        return client
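Usage note: a minimal sketch of driving the new MilvusVector class end to end. The connection values, collection name, documents, and three-dimensional embedding are placeholders, and create() is assumed from the part of this file above this hunk.

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from core.rag.models.document import Document

# Placeholder connection settings; point these at a real Milvus/Zilliz deployment.
vector = MilvusVector(
    collection_name="Vector_index_example_Node",
    config=MilvusConfig(host="127.0.0.1", port=19530, user="root", password="Milvus", secure=False),
)

docs = [Document(page_content="hello milvus", metadata={"doc_id": "doc-1", "document_id": "d-1"})]
vector.create(texts=docs, embeddings=[[0.1, 0.2, 0.3]])    # one embedding per document
print(vector.search_by_vector([0.1, 0.2, 0.3], top_k=2))   # returns the matching Documents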
0
api/core/rag/datasource/vdb/qdrant/__init__.py
Normal file
360
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Normal file
@ -0,0 +1,360 @@
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
import qdrant_client
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.models import (
|
||||
FilterSelector,
|
||||
HnswConfigDiff,
|
||||
PayloadSchemaType,
|
||||
TextIndexParams,
|
||||
TextIndexType,
|
||||
TokenizerType,
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
timeout: float = 20
|
||||
root_path: Optional[str]
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {
|
||||
'path': path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
'timeout': self.timeout
|
||||
}
|
||||
|
||||
|
||||
class QdrantVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
self._distance_func = distance_func.upper()
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'qdrant'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
# get embedding vector size
|
||||
vector_size = len(embeddings[0])
|
||||
# get collection name
|
||||
collection_name = self._collection_name
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
all_collection_name = []
|
||||
collections_response = self._client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
# create collection
|
||||
self.create_collection(collection_name, vector_size)
|
||||
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str, vector_size: int):
|
||||
from qdrant_client.http import models as rest
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[self._distance_func],
|
||||
)
|
||||
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False)
|
||||
self._client.recreate_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
)
|
||||
|
||||
# create payload index
|
||||
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
|
||||
field_schema=PayloadSchemaType.KEYWORD,
|
||||
field_type=PayloadSchemaType.KEYWORD)
|
||||
# create full-text index
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
|
||||
field_schema=text_index_params)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(
|
||||
collection_name=self._collection_name, points=points
|
||||
)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _generate_rest_batches(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
texts_iterator = iter(texts)
|
||||
embeddings_iterator = iter(embeddings)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
|
||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
||||
# Take the corresponding metadata and id for each text in a batch
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
batch_ids = list(islice(ids_iterator, batch_size))
|
||||
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = list(islice(embeddings_iterator, batch_size))
|
||||
|
||||
points = [
|
||||
rest.PointStruct(
|
||||
id=point_id,
|
||||
vector=vector,
|
||||
payload=payload,
|
||||
)
|
||||
for point_id, vector, payload in zip(
|
||||
batch_ids,
|
||||
batch_embeddings,
|
||||
self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
yield batch_ids, points
|
||||
|
||||
@classmethod
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str
|
||||
) -> list[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
if text is None:
|
||||
raise ValueError(
|
||||
"At least one of the texts is None. Please remove it before "
|
||||
"calling .from_texts or .add_texts on Qdrant instance."
|
||||
)
|
||||
metadata = metadatas[i] if metadatas is not None else None
|
||||
payloads.append(
|
||||
{
|
||||
content_payload_key: text,
|
||||
metadata_payload_key: metadata,
|
||||
group_payload_key: group_id
|
||||
}
|
||||
)
|
||||
|
||||
return payloads
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def delete(self):
|
||||
from qdrant_client.http import models
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
from qdrant_client.http import models
|
||||
for node_id in ids:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
response = self._client.retrieve(
|
||||
collection_name=self._collection_name,
|
||||
ids=[id]
|
||||
)
|
||||
|
||||
return len(response) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=kwargs.get("score_threshold", .0)
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
if result.score > score_threshold:
|
||||
metadata['score'] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs most similar by bm25.
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
from qdrant_client.http import models
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
]
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=kwargs.get('top_k', 2),
|
||||
with_payload=True,
|
||||
with_vectors=True
|
||||
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
documents.append(self._document_from_scored_point(
|
||||
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
||||
))
|
||||
|
||||
return documents
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client = cast(QdrantLocal, self._client)
|
||||
self._client._load()
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
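Usage note: a minimal sketch for QdrantVector, assuming a Qdrant server at a placeholder endpoint; a 'path:'-prefixed endpoint would instead select the embedded local mode handled in to_qdrant_params().

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document

# Placeholder endpoint and ids; group_id is the dataset id used for payload filtering.
vector = QdrantVector(
    collection_name="Vector_index_example_Node",
    group_id="dataset-uuid",
    config=QdrantConfig(endpoint="http://localhost:6333", api_key=None, root_path="/tmp"),
)

docs = [Document(page_content="hello qdrant", metadata={"doc_id": "doc-1"})]
vector.create(texts=docs, embeddings=[[0.1, 0.2, 0.3]])   # creates the collection on first use
print(vector.search_by_full_text("hello", top_k=2))       # full-text match over page_content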
62
api/core/rag/datasource/vdb/vector_base.py
Normal file
@ -0,0 +1,62 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from core.rag.models.document import Document


class BaseVector(ABC):

    def __init__(self, collection_name: str):
        self._collection_name = collection_name

    @abstractmethod
    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        raise NotImplementedError

    @abstractmethod
    def text_exists(self, id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_metadata_field(self, key: str, value: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def search_by_vector(
            self,
            query_vector: list[float],
            **kwargs: Any
    ) -> list[Document]:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        raise NotImplementedError

    def delete(self) -> None:
        raise NotImplementedError

    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts:
            doc_id = text.metadata['doc_id']
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
                texts.remove(text)

        return texts

    def _get_uuids(self, texts: list[Document]) -> list[str]:
        return [text.metadata['doc_id'] for text in texts]
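Usage note: a toy, purely illustrative in-memory subclass showing what a new backend must implement to satisfy the BaseVector contract; it is not part of this change.

from typing import Any

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document


class InMemoryVector(BaseVector):
    """Toy backend used only to illustrate the BaseVector interface."""

    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._store: dict[str, Document] = {}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings, **kwargs)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        for doc in documents:
            self._store[doc.metadata['doc_id']] = doc

    def text_exists(self, id: str) -> bool:
        return id in self._store

    def delete_by_ids(self, ids: list[str]) -> None:
        for id_ in ids:
            self._store.pop(id_, None)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._store = {k: d for k, d in self._store.items() if d.metadata.get(key) != value}

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        # No real similarity scoring; just cap the result size like the real backends do.
        return list(self._store.values())[:kwargs.get('top_k', 4)]

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return [d for d in self._store.values() if query in d.page_content]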
171
api/core/rag/datasource/vdb/vector_factory.py
Normal file
@ -0,0 +1,171 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list = None):
|
||||
if attributes is None:
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
self._attributes = attributes
|
||||
self._vector_processor = self._init_vector()
|
||||
|
||||
def _init_vector(self) -> BaseVector:
|
||||
config = cast(dict, current_app.config)
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError("Vector store must be specified.")
|
||||
|
||||
if vector_type == "weaviate":
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
return WeaviateVector(
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=config.get('WEAVIATE_ENDPOINT'),
|
||||
api_key=config.get('WEAVIATE_API_KEY'),
|
||||
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
||||
),
|
||||
attributes=self._attributes
|
||||
)
|
||||
elif vector_type == "qdrant":
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
|
||||
if self._dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Binding does not exist.')
|
||||
else:
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
return QdrantVector(
|
||||
collection_name=collection_name,
|
||||
group_id=self._dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=config.get('QDRANT_URL'),
|
||||
api_key=config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path,
|
||||
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
|
||||
)
|
||||
)
|
||||
elif vector_type == "milvus":
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
return MilvusVector(
|
||||
collection_name=collection_name,
|
||||
config=MilvusConfig(
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
def create(self, texts: list = None, **kwargs):
|
||||
if texts:
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
||||
self._vector_processor.create(
|
||||
texts=texts,
|
||||
embeddings=embeddings,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
if kwargs.get('duplicate_check', False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
|
||||
self._vector_processor.add_texts(
|
||||
documents=documents,
|
||||
embeddings=embeddings,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self._vector_processor.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._vector_processor.delete_by_ids(ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._vector_processor.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
query_vector = self._embeddings.embed_query(query)
|
||||
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._vector_processor.delete()
|
||||
|
||||
def _get_embeddings(self) -> Embeddings:
|
||||
model_manager = ModelManager()
|
||||
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model
|
||||
|
||||
)
|
||||
return CacheEmbedding(embedding_model)
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts:
|
||||
doc_id = text.metadata['doc_id']
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
return texts
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._vector_processor is not None:
|
||||
method = getattr(self._vector_processor, name)
|
||||
if callable(method):
|
||||
return method
|
||||
|
||||
raise AttributeError(f"'vector_processor' object has no attribute '{name}'")
|
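Usage note: a sketch of the Vector facade, assuming an existing Dataset row with an embedding model configured, a Flask application context, and one of the supported VECTOR_STORE settings; the lookup and metadata values are placeholders.

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).first()   # hypothetical lookup of a configured dataset
vector = Vector(dataset)

doc = Document(page_content="hello", metadata={"doc_id": "doc-1", "dataset_id": dataset.id,
                                               "document_id": "d-1", "doc_hash": "hash"})
vector.add_texts([doc], duplicate_check=True)  # embeds the text, skips docs already indexed
print(vector.search_by_vector("greeting", top_k=4))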
0
api/core/rag/datasource/vdb/weaviate/__init__.py
Normal file
235
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
Normal file
@ -0,0 +1,235 @@
|
||||
import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class WeaviateVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
super().__init__(collection_name)
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
||||
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
|
||||
|
||||
weaviate.connect.connection.has_grpc = False
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint,
|
||||
auth_client_secret=auth_config,
|
||||
timeout_config=(5, 60),
|
||||
startup_period=None
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
|
||||
client.batch.configure(
|
||||
# `batch_size` takes an `int` value to enable auto-batching
|
||||
# (`None` is used for manual batching)
|
||||
batch_size=config.batch_size,
|
||||
# dynamically update the `batch_size` based on import speed
|
||||
dynamic=True,
|
||||
# `timeout_retries` takes an `int` value to retry on time outs
|
||||
timeout_retries=3,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'weaviate'
|
||||
|
||||
def get_collection_name(self, dataset: Dataset) -> str:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
||||
schema = self._default_schema(self._collection_name)
|
||||
|
||||
# check whether the index already exists
|
||||
if not self._client.schema.contains(schema):
|
||||
# create collection
|
||||
self._client.schema.create_class(schema)
|
||||
# create vector
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
ids = []
|
||||
|
||||
with self._client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {Field.TEXT_KEY.value: text}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
data_properties[key] = self._json_serializable(val)
|
||||
|
||||
batch.add_data_object(
|
||||
data_object=data_properties,
|
||||
class_name=self._collection_name,
|
||||
uuid=uuids[i],
|
||||
vector=embeddings[i] if embeddings else None,
|
||||
)
|
||||
ids.append(uuids[i])
|
||||
return ids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
where_filter = {
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
}
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
|
||||
def delete(self):
|
||||
self._client.schema.delete_class(self._collection_name)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
collection_name = self._collection_name
|
||||
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": id,
|
||||
}).with_limit(1).do()
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
entries = result["data"]["Get"][collection_name]
|
||||
if len(entries) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.data_object.delete(
|
||||
ids,
|
||||
class_name=self._collection_name
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
collection_name = self._collection_name
|
||||
properties = self._attributes
|
||||
properties.append(Field.TEXT_KEY.value)
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(kwargs.get("top_k", 4))
|
||||
.with_additional(["vector", "distance"])
|
||||
.do()
|
||||
)
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
docs_and_scores = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
score = 1 - res["_additional"]["distance"]
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
doc.metadata['score'] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs using BM25F.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
collection_name = self._collection_name
|
||||
content: dict[str, Any] = {"concepts": [query]}
|
||||
properties = self._attributes
|
||||
properties.append(Field.TEXT_KEY.value)
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def _default_schema(self, index_name: str) -> dict:
|
||||
return {
|
||||
"class": index_name,
|
||||
"properties": [
|
||||
{
|
||||
"name": "text",
|
||||
"dataType": ["text"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _json_serializable(self, value: Any) -> Any:
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.isoformat()
|
||||
return value
|
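Usage note: a minimal sketch for WeaviateVector with placeholder endpoint and credentials; the attribute list mirrors the defaults used by the Vector factory above.

from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from core.rag.models.document import Document

vector = WeaviateVector(
    collection_name="Vector_index_example_Node",
    config=WeaviateConfig(endpoint="http://localhost:8080", api_key="placeholder-key", batch_size=100),
    attributes=['doc_id', 'dataset_id', 'document_id', 'doc_hash'],
)

docs = [Document(page_content="hello weaviate", metadata={"doc_id": "doc-1"})]
vector.create(texts=docs, embeddings=[[0.1, 0.2, 0.3]])
print(vector.search_by_full_text("hello", top_k=2))   # BM25 query against the 'text' property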
166
api/core/rag/extractor/blod/blod.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""Schema for Blobs and Blob Loaders.
|
||||
|
||||
The goal is to facilitate decoupling of content loading from content parsing code.
|
||||
|
||||
In addition, content loading code should provide a lazy loading interface by default.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import mimetypes
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Iterable, Mapping
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
|
||||
class Blob(BaseModel):
|
||||
"""A blob is used to represent raw data by either reference or value.
|
||||
|
||||
Provides an interface to materialize the blob in different representations, and
|
||||
help to decouple the development of data loaders from the downstream parsing of
|
||||
the raw data.
|
||||
|
||||
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
||||
"""
|
||||
|
||||
data: Union[bytes, str, None] # Raw data
|
||||
mimetype: Optional[str] = None # Not to be confused with a file extension
|
||||
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
||||
# Location where the original content was found
|
||||
# Represent location on the local file system
|
||||
# Useful for situations where downstream code assumes it must work with file paths
|
||||
# rather than in-memory content.
|
||||
path: Optional[PathLike] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
"""The source location of the blob as string if known otherwise none."""
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
return values
|
||||
|
||||
def as_string(self) -> str:
|
||||
"""Read data as a string."""
|
||||
if self.data is None and self.path:
|
||||
with open(str(self.path), encoding=self.encoding) as f:
|
||||
return f.read()
|
||||
elif isinstance(self.data, bytes):
|
||||
return self.data.decode(self.encoding)
|
||||
elif isinstance(self.data, str):
|
||||
return self.data
|
||||
else:
|
||||
raise ValueError(f"Unable to get string for blob {self}")
|
||||
|
||||
def as_bytes(self) -> bytes:
|
||||
"""Read data as bytes."""
|
||||
if isinstance(self.data, bytes):
|
||||
return self.data
|
||||
elif isinstance(self.data, str):
|
||||
return self.data.encode(self.encoding)
|
||||
elif self.data is None and self.path:
|
||||
with open(str(self.path), "rb") as f:
|
||||
return f.read()
|
||||
else:
|
||||
raise ValueError(f"Unable to get bytes for blob {self}")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
|
||||
"""Read data as a byte stream."""
|
||||
if isinstance(self.data, bytes):
|
||||
yield BytesIO(self.data)
|
||||
elif self.data is None and self.path:
|
||||
with open(str(self.path), "rb") as f:
|
||||
yield f
|
||||
else:
|
||||
raise NotImplementedError(f"Unable to convert blob {self}")
|
||||
|
||||
@classmethod
|
||||
def from_path(
|
||||
cls,
|
||||
path: PathLike,
|
||||
*,
|
||||
encoding: str = "utf-8",
|
||||
mime_type: Optional[str] = None,
|
||||
guess_type: bool = True,
|
||||
) -> Blob:
|
||||
"""Load the blob from a path like object.
|
||||
|
||||
Args:
|
||||
path: path like object to file to be read
|
||||
encoding: Encoding to use if decoding the bytes into a string
|
||||
mime_type: if provided, will be set as the mime-type of the data
|
||||
guess_type: If True, the mimetype will be guessed from the file extension,
|
||||
if a mime-type was not provided
|
||||
|
||||
Returns:
|
||||
Blob instance
|
||||
"""
|
||||
if mime_type is None and guess_type:
|
||||
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
|
||||
else:
|
||||
_mimetype = mime_type
|
||||
# We do not load the data immediately, instead we treat the blob as a
|
||||
# reference to the underlying data.
|
||||
return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
|
||||
|
||||
@classmethod
|
||||
def from_data(
|
||||
cls,
|
||||
data: Union[str, bytes],
|
||||
*,
|
||||
encoding: str = "utf-8",
|
||||
mime_type: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
) -> Blob:
|
||||
"""Initialize the blob from in-memory data.
|
||||
|
||||
Args:
|
||||
data: the in-memory data associated with the blob
|
||||
encoding: Encoding to use if decoding the bytes into a string
|
||||
mime_type: if provided, will be set as the mime-type of the data
|
||||
path: if provided, will be set as the source from which the data came
|
||||
|
||||
Returns:
|
||||
Blob instance
|
||||
"""
|
||||
return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Define the blob representation."""
|
||||
str_repr = f"Blob {id(self)}"
|
||||
if self.source:
|
||||
str_repr += f" {self.source}"
|
||||
return str_repr
|
||||
|
||||
|
||||
class BlobLoader(ABC):
|
||||
"""Abstract interface for blob loaders implementation.
|
||||
|
||||
Implementer should be able to load raw content from a datasource system according
|
||||
to some criteria and return the raw content lazily as a stream of blobs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def yield_blobs(
|
||||
self,
|
||||
) -> Iterable[Blob]:
|
||||
"""A lazy loader for raw data represented by LangChain's Blob object.
|
||||
|
||||
Returns:
|
||||
A generator over blobs
|
||||
"""
|
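Usage note: a short sketch of the Blob helper above; the file path and in-memory payload are placeholders.

from core.rag.extractor.blod.blod import Blob

# Reference a file lazily, then materialize it in different representations.
blob = Blob.from_path("example.txt", encoding="utf-8")   # hypothetical file
print(blob.mimetype)                                     # guessed from the extension
print(blob.as_string())                                  # read as text
with blob.as_bytes_io() as stream:                       # or stream the raw bytes
    head = stream.read(16)

in_memory = Blob.from_data("already loaded", mime_type="text/plain")
print(in_memory.as_bytes())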
71
api/core/rag/extractor/csv_extractor.py
Normal file
@ -0,0 +1,71 @@
"""Abstract interface for document loader implementations."""
import csv
from typing import Optional

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document


class CSVExtractor(BaseExtractor):
    """Load CSV files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
            self,
            file_path: str,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = False,
            source_column: Optional[str] = None,
            csv_args: Optional[dict] = None,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._encoding = encoding
        self._autodetect_encoding = autodetect_encoding
        self.source_column = source_column
        self.csv_args = csv_args or {}

    def extract(self) -> list[Document]:
        """Load data into document objects."""
        try:
            with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(self._file_path)
                for encoding in detected_encodings:
                    try:
                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self._read_from_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self._file_path}") from e

        return docs

    def _read_from_file(self, csvfile) -> list[Document]:
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else ''
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            metadata = {"source": source, "row": i}
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
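Usage note: a sketch of CSVExtractor on a hypothetical faq.csv whose 'url' column is used as the document source.

from core.rag.extractor.csv_extractor import CSVExtractor

extractor = CSVExtractor("faq.csv", autodetect_encoding=True, source_column="url")
for doc in extractor.extract():
    print(doc.metadata["source"], doc.metadata["row"])   # source column value and row index
    print(doc.page_content)                              # "header: value" lines for the row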
6
api/core/rag/extractor/entity/datasource_type.py
Normal file
@ -0,0 +1,6 @@
from enum import Enum


class DatasourceType(Enum):
    FILE = "upload_file"
    NOTION = "notion_import"
36
api/core/rag/extractor/entity/extract_setting.py
Normal file
@ -0,0 +1,36 @@
from pydantic import BaseModel

from models.dataset import Document
from models.model import UploadFile


class NotionInfo(BaseModel):
    """
    Notion import info.
    """
    notion_workspace_id: str
    notion_obj_id: str
    notion_page_type: str
    document: Document = None

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data) -> None:
        super().__init__(**data)


class ExtractSetting(BaseModel):
    """
    Model class for document extract setting.
    """
    datasource_type: str
    upload_file: UploadFile = None
    notion_info: NotionInfo = None
    document_model: str = None

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data) -> None:
        super().__init__(**data)
@ -1,14 +1,14 @@
|
||||
import logging
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from openpyxl.reader.excel import load_workbook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class ExcelLoader(BaseLoader):
|
||||
"""Load xlxs files.
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
||||
Args:
|
||||
@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
data = []
|
||||
keys = []
|
||||
wb = load_workbook(filename=self._file_path, read_only=True)
|
139
api/core/rag/extractor/extract_processor.py
Normal file
@ -0,0 +1,139 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@classmethod
|
||||
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
|
||||
-> Union[list[Document], str]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=upload_file,
|
||||
document_model='text_model'
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
|
||||
else:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
document_model='text_model'
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(
|
||||
extract_setting=extract_setting, file_path=file_path)])
|
||||
else:
|
||||
return cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
|
||||
@classmethod
|
||||
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
|
||||
file_path: str = None) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
upload_file: UploadFile = extract_setting.upload_file
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx':
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.csv':
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.pptx':
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
else TextExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
if file_extension == '.xlsx':
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
extractor = WordExtractor(file_path)
|
||||
elif file_extension == '.csv':
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
extractor = TextExtractor(file_path, autodetect_encoding=True)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
|
||||
notion_obj_id=extract_setting.notion_info.notion_obj_id,
|
||||
notion_page_type=extract_setting.notion_info.notion_page_type,
|
||||
document_model=extract_setting.notion_info.document
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
|
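Usage note: a sketch of the two ExtractProcessor entry points; the URL is a placeholder, and load_from_upload_file() needs a Flask application context plus access to the configured storage backend.

from core.rag.extractor.extract_processor import ExtractProcessor

# Pull a remote PDF (or plain-text page) into a single string.
text = ExtractProcessor.load_from_url("https://example.com/sample.pdf", return_text=True)
print(text[:200])

# For an existing UploadFile row, extract Document objects instead:
# documents = ExtractProcessor.load_from_upload_file(upload_file, is_automatic=False)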
12
api/core/rag/extractor/extractor_base.py
Normal file
@ -0,0 +1,12 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod


class BaseExtractor(ABC):
    """Interface for extract files.
    """

    @abstractmethod
    def extract(self):
        raise NotImplementedError
46
api/core/rag/extractor/helpers.py
Normal file
@ -0,0 +1,46 @@
"""Document loader helpers."""

import concurrent.futures
from typing import NamedTuple, Optional, cast


class FileEncoding(NamedTuple):
    """A file encoding as the NamedTuple."""

    encoding: Optional[str]
    """The encoding of the file."""
    confidence: float
    """The confidence of the encoding."""
    language: Optional[str]
    """The language of the file."""


def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
    """Try to detect the file encoding.

    Returns a list of `FileEncoding` tuples with the detected encodings ordered
    by confidence.

    Args:
        file_path: The path to the file to detect the encoding for.
        timeout: The timeout in seconds for the encoding detection.
    """
    import chardet

    def read_and_detect(file_path: str) -> list[dict]:
        with open(file_path, "rb") as f:
            rawdata = f.read()
        return cast(list[dict], chardet.detect_all(rawdata))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(read_and_detect, file_path)
        try:
            encodings = future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            raise TimeoutError(
                f"Timeout reached while detecting encoding for {file_path}"
            )

    if all(encoding["encoding"] is None for encoding in encodings):
        raise RuntimeError(f"Could not detect encoding for {file_path}")
    return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
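Usage note: a sketch of detect_file_encodings() on a placeholder file whose encoding is unknown; candidates come back ordered by chardet confidence.

from core.rag.extractor.helpers import detect_file_encodings

for candidate in detect_file_encodings("legacy.txt", timeout=5):
    print(candidate.encoding, candidate.confidence, candidate.language)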
@ -1,51 +1,55 @@
|
||||
import csv
|
||||
import logging
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from langchain.document_loaders import CSVLoader as LCCSVLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class CSVLoader(LCCSVLoader):
|
||||
class HtmlExtractor(BaseExtractor):
|
||||
"""Load html files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
self.file_path = file_path
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
self.source_column = source_column
|
||||
self.encoding = encoding
|
||||
self.csv_args = csv_args or {}
|
||||
self.autodetect_encoding = autodetect_encoding
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
|
||||
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self.autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self.file_path)
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
logger.debug("Trying encoding: ", encoding.encoding)
|
||||
try:
|
||||
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self.file_path}") from e
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def _read_from_file(self, csvfile):
|
||||
def _read_from_file(self, csvfile) -> list[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
@ -1,39 +1,27 @@
|
||||
import logging
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import re
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class MarkdownLoader(BaseLoader):
|
||||
"""Load md files.
|
||||
class MarkdownExtractor(BaseExtractor):
|
||||
"""Load Markdown files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
|
||||
remove_hyperlinks: Whether to remove hyperlinks from the text.
|
||||
|
||||
remove_images: Whether to remove images from the text.
|
||||
|
||||
encoding: File encoding to use. If `None`, the file will be loaded
|
||||
with the default system encoding.
|
||||
|
||||
autodetect_encoding: Whether to try to autodetect the file encoding
|
||||
if the specified encoding fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = True,
|
||||
remove_images: bool = True,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = True,
|
||||
remove_images: bool = True,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader):
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
tups = self.parse_tups(self._file_path)
|
||||
documents = []
|
||||
for header, value in tups:
|
||||
@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader):
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(filepath)
|
||||
for encoding in detected_encodings:
|
||||
logger.debug("Trying encoding: ", encoding.encoding)
|
||||
try:
|
||||
with open(filepath, encoding=encoding.encoding) as f:
|
||||
content = f.read()
|
@ -4,9 +4,10 @@ from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.source import DataSourceBinding
|
||||
@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
||||
|
||||
|
||||
class NotionLoader(BaseLoader):
|
||||
class NotionExtractor(BaseExtractor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notion_access_token: str,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
document_model: Optional[DocumentModel] = None
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
self._notion_workspace_id = notion_workspace_id
|
||||
self._notion_obj_id = notion_obj_id
|
||||
self._notion_page_type = notion_page_type
|
||||
self._notion_access_token = notion_access_token
|
||||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
|
||||
self._notion_workspace_id)
|
||||
if not self._notion_access_token:
|
||||
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment "
|
||||
"variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
|
||||
if not self._notion_access_token:
|
||||
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment "
|
||||
"variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document_model: DocumentModel):
|
||||
data_source_info = document_model.data_source_info_dict
|
||||
if not data_source_info or 'notion_page_id' not in data_source_info \
|
||||
or 'notion_workspace_id' not in data_source_info:
|
||||
raise ValueError("no notion page found")
|
||||
|
||||
notion_workspace_id = data_source_info['notion_workspace_id']
|
||||
notion_obj_id = data_source_info['notion_page_id']
|
||||
notion_page_type = data_source_info['type']
|
||||
notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
|
||||
|
||||
return cls(
|
||||
notion_access_token=notion_access_token,
|
||||
notion_workspace_id=notion_workspace_id,
|
||||
notion_obj_id=notion_obj_id,
|
||||
notion_page_type=notion_page_type,
|
||||
document_model=document_model
|
||||
)
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
self.update_last_edited_time(
|
||||
self._document_model
|
||||
)
|
72
api/core/rag/extractor/pdf_extractor.py
Normal file
72
api/core/rag/extractor/pdf_extractor.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.blod.blod import Blob
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class PdfExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_cache_key: Optional[str] = None
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = self._file_cache_key  # cache the extracted plaintext under the provided key
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
text = storage.load(self._file_cache_key).decode('utf-8')
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
documents = list(self.load())
|
||||
text_list = []
|
||||
for document in documents:
|
||||
text_list.append(document.page_content)
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode('utf-8'))
|
||||
|
||||
return documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""Lazy load given path as pages."""
|
||||
blob = Blob.from_path(self._file_path)
|
||||
yield from self.parse(blob)
|
||||
|
||||
def parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
import pypdfium2
|
||||
|
||||
with blob.as_bytes_io() as file_path:
|
||||
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
|
||||
try:
|
||||
for page_number, page in enumerate(pdf_reader):
|
||||
text_page = page.get_textpage()
|
||||
content = text_page.get_text_range()
|
||||
text_page.close()
|
||||
page.close()
|
||||
metadata = {"source": blob.source, "page": page_number}
|
||||
yield Document(page_content=content, metadata=metadata)
|
||||
finally:
|
||||
pdf_reader.close()
|
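A minimal usage sketch of the new PdfExtractor, assuming the module path introduced above and a hypothetical local PDF path; without a file_cache_key the extractor parses the file directly and yields one Document per page.

from core.rag.extractor.pdf_extractor import PdfExtractor

# Parse a local PDF; each page becomes one Document carrying source/page metadata.
pdf_extractor = PdfExtractor(file_path="/tmp/sample.pdf")  # hypothetical path
for doc in pdf_extractor.extract():
    print(doc.metadata["page"], doc.page_content[:80])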
50
api/core/rag/extractor/text_extractor.py
Normal file
50
api/core/rag/extractor/text_extractor.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class TextExtractor(BaseExtractor):
|
||||
"""Load text files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
text = ""
|
||||
try:
|
||||
with open(self._file_path, encoding=self._encoding) as f:
|
||||
text = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, encoding=encoding.encoding) as f:
|
||||
text = f.read()
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
metadata = {"source": self._file_path}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
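A short sketch of the TextExtractor's encoding fallback, assuming the module path above and a hypothetical text file; when the given encoding fails and autodetect_encoding is True, detect_file_encodings supplies candidate encodings before the extractor gives up.

from core.rag.extractor.text_extractor import TextExtractor

# Try UTF-8 first; on UnicodeDecodeError fall back to the detected encodings.
text_extractor = TextExtractor("/tmp/notes.txt", encoding="utf-8", autodetect_encoding=True)  # hypothetical path
docs = text_extractor.extract()
print(docs[0].metadata["source"], len(docs[0].page_content))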
@ -0,0 +1,61 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredWordExtractor(BaseExtractor):
|
||||
"""Loader that uses unstructured to load word documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
|
||||
unstructured_version = tuple(
|
||||
[int(x) for x in __unstructured_version__.split(".")]
|
||||
)
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
_, extension = os.path.splitext(str(self._file_path))
|
||||
is_doc = extension == ".doc"
|
||||
|
||||
if is_doc and unstructured_version < (0, 4, 11):
|
||||
raise ValueError(
|
||||
f"You are on unstructured version {__unstructured_version__}. "
|
||||
"Partitioning .doc files is only supported in unstructured>=0.4.11. "
|
||||
"Please upgrade the unstructured package and try again."
|
||||
)
|
||||
|
||||
if is_doc:
|
||||
from unstructured.partition.doc import partition_doc
|
||||
|
||||
elements = partition_doc(filename=self._file_path)
|
||||
else:
|
||||
from unstructured.partition.docx import partition_docx
|
||||
|
||||
elements = partition_docx(filename=self._file_path)
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
@ -2,13 +2,14 @@ import base64
|
||||
import logging
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredEmailLoader(BaseLoader):
|
||||
class UnstructuredEmailExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
||||
|
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMarkdownLoader(BaseLoader):
|
||||
class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
"""Load md files.
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMsgLoader(BaseLoader):
|
||||
class UnstructuredMsgExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UnstructuredPPTLoader(BaseLoader):
|
||||
|
||||
class UnstructuredPPTExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
@ -1,10 +1,12 @@
|
||||
import logging
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class UnstructuredPPTXLoader(BaseLoader):
|
||||
|
||||
|
||||
class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredTextLoader(BaseLoader):
|
||||
class UnstructuredTextExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredXmlLoader(BaseLoader):
|
||||
class UnstructuredXmlExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
62
api/core/rag/extractor/word_extractor.py
Normal file
62
api/core/rag/extractor/word_extractor.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import os
|
||||
import tempfile
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class WordExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
if "~" in self.file_path:
|
||||
self.file_path = os.path.expanduser(self.file_path)
|
||||
|
||||
# If the file is a web path, download it to a temporary file, and use that
|
||||
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
|
||||
r = requests.get(self.file_path)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise ValueError(
|
||||
"Check the url of your file; returned status code %s"
|
||||
% r.status_code
|
||||
)
|
||||
|
||||
self.web_path = self.file_path
|
||||
self.temp_file = tempfile.NamedTemporaryFile()
|
||||
self.temp_file.write(r.content)
|
||||
self.file_path = self.temp_file.name
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError("File path %s is not a valid file or url" % self.file_path)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
import docx2txt
|
||||
|
||||
return [
|
||||
Document(
|
||||
page_content=docx2txt.process(self.file_path),
|
||||
metadata={"source": self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_url(url: str) -> bool:
|
||||
"""Check if the url is valid."""
|
||||
parsed = urlparse(url)
|
||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
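A minimal sketch of the WordExtractor's two input modes, assuming the module path above; the local path and URL below are hypothetical, and a remote file is first downloaded into a NamedTemporaryFile that lives as long as the extractor does.

from core.rag.extractor.word_extractor import WordExtractor

# Local .docx: processed in place by docx2txt.
local_docs = WordExtractor("/tmp/spec.docx").extract()  # hypothetical path

# Remote .docx: downloaded to a temporary file before extraction.
remote_docs = WordExtractor("https://example.com/spec.docx").extract()  # hypothetical URL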
0
api/core/rag/index_processor/__init__.py
Normal file
0
api/core/rag/index_processor/__init__.py
Normal file
0
api/core/rag/index_processor/constant/__init__.py
Normal file
0
api/core/rag/index_processor/constant/__init__.py
Normal file
8
api/core/rag/index_processor/constant/index_type.py
Normal file
8
api/core/rag/index_processor/constant/index_type.py
Normal file
@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class IndexType(Enum):
|
||||
PARAGRAPH_INDEX = "text_model"
|
||||
QA_INDEX = "qa_model"
|
||||
PARENT_CHILD_INDEX = "parent_child_index"
|
||||
SUMMARY_INDEX = "summary_index"
|
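The enum values double as the doc_form strings passed around the indexing code; a quick sanity check, assuming the module path above. Only the first two values are handled by the factory introduced below.

from core.rag.index_processor.constant.index_type import IndexType

# "text_model" and "qa_model" are the doc_form values the factory dispatches on.
assert IndexType.PARAGRAPH_INDEX.value == "text_model"
assert IndexType.QA_INDEX.value == "qa_model"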
70
api/core/rag/index_processor/index_processor_base.py
Normal file
70
api/core/rag/index_processor/index_processor_base.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
raise NotImplementedError
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_splitter(self, processing_rule: dict,
|
||||
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
|
||||
"""
|
||||
Get the text splitter according to the processing rule.
|
||||
"""
|
||||
if processing_rule['mode'] == "custom":
|
||||
# The user-defined segmentation rule
|
||||
rules = processing_rule['rules']
|
||||
segmentation = rules["segmentation"]
|
||||
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
|
||||
raise ValueError("Custom segment length should be between 50 and 1000.")
|
||||
|
||||
separator = segmentation["separator"]
|
||||
if separator:
|
||||
separator = separator.replace('\\n', '\n')
|
||||
|
||||
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
|
||||
chunk_size=segmentation["max_tokens"],
|
||||
chunk_overlap=0,
|
||||
fixed_separator=separator,
|
||||
separators=["\n\n", "。", ".", " ", ""],
|
||||
embedding_model_instance=embedding_model_instance
|
||||
)
|
||||
else:
|
||||
# Automatic segmentation
|
||||
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
|
||||
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
||||
chunk_overlap=0,
|
||||
separators=["\n\n", "。", ".", " ", ""],
|
||||
embedding_model_instance=embedding_model_instance
|
||||
)
|
||||
|
||||
return character_splitter
|
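For reference, a sketch of the processing_rule shape that _get_splitter expects, inferred from the custom-mode branch above; the values are illustrative.

# Custom segmentation rule accepted by BaseIndexProcessor._get_splitter (illustrative values).
process_rule = {
    "mode": "custom",
    "rules": {
        "segmentation": {
            "max_tokens": 500,   # must be between 50 and 1000
            "separator": "\\n",  # the literal backslash-n is unescaped to a real newline
        }
    }
}
# Any other mode falls back to the automatic splitter built from
# DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'].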
28
api/core/rag/index_processor/index_processor_factory.py
Normal file
28
api/core/rag/index_processor/index_processor_factory.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
|
||||
|
||||
|
||||
class IndexProcessorFactory:
|
||||
"""IndexProcessorInit.
|
||||
"""
|
||||
|
||||
def __init__(self, index_type: str):
|
||||
self._index_type = index_type
|
||||
|
||||
def init_index_processor(self) -> BaseIndexProcessor:
|
||||
"""Init index processor."""
|
||||
|
||||
if not self._index_type:
|
||||
raise ValueError("Index type must be specified.")
|
||||
|
||||
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
|
||||
return ParagraphIndexProcessor()
|
||||
elif self._index_type == IndexType.QA_INDEX.value:
|
||||
|
||||
return QAIndexProcessor()
|
||||
else:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
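A minimal sketch of the factory in use, assuming the module paths above; passing anything other than the paragraph or QA index type raises ValueError.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# Returns a ParagraphIndexProcessor; IndexType.QA_INDEX.value would return a QAIndexProcessor.
index_processor = IndexProcessorFactory(IndexType.PARAGRAPH_INDEX.value).init_index_processor()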
0
api/core/rag/index_processor/processor/__init__.py
Normal file
0
api/core/rag/index_processor/processor/__init__.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""Paragraph index processor."""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=kwargs.get('process_rule_mode') == "automatic")
|
||||
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
# Split the text documents into nodes.
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=kwargs.get('embedding_model_instance'))
|
||||
all_documents = []
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document.page_content = document_text
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
# strip the leading separator character left over from splitting
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
page_content = page_content[1:]
|
||||
document_node.page_content = page_content
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
keyword.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
else:
|
||||
vector.delete()
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
if node_ids:
|
||||
keyword.delete_by_ids(node_ids)
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata['score'] = result.score
|
||||
if result.score > score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
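A sketch of the extract → transform → load order the ParagraphIndexProcessor is designed around, assuming the module paths above; extract_setting, process_rule, embedding_model_instance and dataset are placeholders for objects the caller (e.g. the indexing runner) already holds.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

processor = IndexProcessorFactory(IndexType.PARAGRAPH_INDEX.value).init_index_processor()

# The variables below are assumed to be prepared by the caller; shown only to illustrate the flow.
documents = processor.extract(extract_setting, process_rule_mode="automatic")
nodes = processor.transform(documents,
                            process_rule=process_rule,
                            embedding_model_instance=embedding_model_instance)
processor.load(dataset, nodes, with_keywords=True)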
161
api/core/rag/index_processor/processor/qa_index_processor.py
Normal file
161
api/core/rag/index_processor/processor/qa_index_processor.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Paragraph index processor."""
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
from flask_login import current_user
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=kwargs.get('process_rule_mode') == "automatic")
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=None)
|
||||
|
||||
# Split the text documents into nodes.
|
||||
all_documents = []
|
||||
all_qa_documents = []
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document.page_content = document_text
|
||||
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
# strip the leading separator character left over from splitting
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
page_content = page_content[1:]
|
||||
document_node.page_content = page_content
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
for i in range(0, len(all_documents), 10):
|
||||
threads = []
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'tenant_id': current_user.current_tenant.id,
|
||||
'document_node': doc,
|
||||
'all_qa_documents': all_qa_documents,
|
||||
'document_language': kwargs.get('document_language', 'English')})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
return all_qa_documents
|
||||
|
||||
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
|
||||
|
||||
# check file type
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
# pandas treats the first CSV row as the header, so data rows start from the second row
|
||||
df = pd.read_csv(file)
|
||||
text_docs = []
|
||||
for index, row in df.iterrows():
|
||||
data = Document(page_content=row[0], metadata={'answer': row[1]})
|
||||
text_docs.append(data)
|
||||
if len(text_docs) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
return text_docs
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata['score'] = result.score
|
||||
if result.score > score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
return
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
|
||||
document_qa_list = self._format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
qa_document.metadata['doc_id'] = doc_id
|
||||
qa_document.metadata['doc_hash'] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
def _format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
|
||||
matches = re.findall(regex, text, re.UNICODE)
|
||||
|
||||
return [
|
||||
{
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
}
|
||||
for q, a in matches if q and a
|
||||
]
|
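The QA post-processing hinges on the regex in _format_split_text; a self-contained check of the Q/A format the generator is expected to produce (the sample text is illustrative).

import re

sample = "Q1: What is the platform?\nA1: An LLM app builder.\nQ2: Which formats are supported?\nA2: PDF, DOCX\nand more."
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
pairs = [
    {"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())}
    for q, a in re.findall(regex, sample, re.UNICODE)
]
# pairs[0] == {'question': 'What is the platform?', 'answer': 'An LLM app builder.'}
# pairs[1]['answer'] keeps the embedded newline: 'PDF, DOCX\nand more.'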
0
api/core/rag/models/__init__.py
Normal file
0
api/core/rag/models/__init__.py
Normal file
16
api/core/rag/models/document.py
Normal file
16
api/core/rag/models/document.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
metadata: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
|
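The new Document model is a plain pydantic object that replaces langchain.schema.Document across the extractors; a minimal construction example with an illustrative metadata key.

from core.rag.models.document import Document

doc = Document(page_content="hello world", metadata={"source": "inline-example"})
assert doc.metadata["source"] == "inline-example"
# metadata defaults to an empty dict when omitted:
assert Document(page_content="no metadata").metadata == {}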
@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
|
||||
from regex import regex
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.data_loader import file_extractor
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
@ -146,7 +146,7 @@ def get_url(url: str) -> str:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
|
||||
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
|
||||
|
||||
@ -158,8 +158,8 @@ def get_url(url: str) -> str:
|
||||
if main_content_type not in supported_content_types:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return FileExtractor.load_from_url(url, return_text=True)
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
|
||||
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
|
||||
a = extract_using_readabilipy(response.text)
|
||||
|
@ -6,15 +6,12 @@ from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
kw_table_index = KeywordTableIndex(
|
||||
dataset=dataset,
|
||||
config=KeywordTableConfig(
|
||||
max_keywords_per_chunk=5
|
||||
)
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
return []
|
||||
except ProviderTokenNotInitError:
|
||||
return []
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
documents = []
|
||||
threads = []
|
||||
if self.top_k > 0:
|
||||
# retrieval_model source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
|
||||
'search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
'reranking_model': None,
|
||||
'all_documents': documents,
|
||||
'search_method': 'hybrid_search',
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval_model source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
|
||||
'search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': 'hybrid_search',
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model[
|
||||
'score_threshold'] if retrieval_model[
|
||||
'score_threshold_enabled'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model[
|
||||
'reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
@ -1,20 +1,12 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool):
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
kw_table_index = KeywordTableIndex(
|
||||
dataset=dataset,
|
||||
config=KeywordTableConfig(
|
||||
max_keywords_per_chunk=5
|
||||
)
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
# get embedding model instance
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return ''
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
documents = []
|
||||
threads = []
|
||||
if self.top_k > 0:
|
||||
# retrieval source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enabled'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval_model source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enabled'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# hybrid search: rerank after all documents have been searched
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
# get rerank model instance
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return ''
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
documents = rerank_runner.run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enabled'] else None,
|
||||
top_n=self.top_k
|
||||
)
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool):
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
|
@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
|
||||
from regex import regex
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.data_loader import file_extractor
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str:
|
||||
if user_agent:
|
||||
headers["User-Agent"] = user_agent
|
||||
|
||||
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
|
||||
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
|
||||
|
||||
@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str:
|
||||
if main_content_type not in supported_content_types:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return FileExtractor.load_from_url(url, return_text=True)
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
|
||||
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
|
||||
a = extract_using_readabilipy(response.text)
|
||||
|
@ -1,56 +0,0 @@
|
||||
from core.vector_store.vector.milvus import Milvus
|
||||
|
||||
|
||||
class MilvusVectorStore(Milvus):
|
||||
def del_texts(self, where_filter: dict):
|
||||
if not where_filter:
|
||||
raise ValueError('where_filter must not be empty')
|
||||
|
||||
self.col.delete(where_filter.get('filter'))
|
||||
|
||||
def del_text(self, uuid: str) -> None:
|
||||
expr = f"id == {uuid}"
|
||||
self.col.delete(expr)
|
||||
|
||||
def text_exists(self, uuid: str) -> bool:
|
||||
result = self.col.query(
|
||||
expr=f'metadata["doc_id"] == "{uuid}"',
|
||||
output_fields=["id"]
|
||||
)
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def get_ids_by_document_id(self, document_id: str):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["document_id"] == "{document_id}"',
|
||||
output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["{key}"] == "{value}"',
|
||||
output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_ids_by_doc_ids(self, doc_ids: list):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["doc_id"] in {doc_ids}',
|
||||
output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete(self):
|
||||
from pymilvus import utility
|
||||
utility.drop_collection(self.collection_name, None, self.alias)
|
||||
|
@ -1,76 +0,0 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.schema import Document
|
||||
from qdrant_client.http.models import Filter, FilterSelector, PointIdsList
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.vector_store.vector.qdrant import Qdrant
|
||||
|
||||
|
||||
class QdrantVectorStore(Qdrant):
|
||||
def del_texts(self, filter: Filter):
|
||||
if not filter:
|
||||
raise ValueError('filter must not be empty')
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def del_text(self, uuid: str) -> None:
|
||||
self._reload_if_needed()
|
||||
|
||||
self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=PointIdsList(
|
||||
points=[uuid],
|
||||
),
|
||||
)
|
||||
|
||||
def text_exists(self, uuid: str) -> bool:
|
||||
self._reload_if_needed()
|
||||
|
||||
response = self.client.retrieve(
|
||||
collection_name=self.collection_name,
|
||||
ids=[uuid]
|
||||
)
|
||||
|
||||
return len(response) > 0
|
||||
|
||||
def delete(self):
|
||||
self._reload_if_needed()
|
||||
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
def delete_group(self):
|
||||
self._reload_if_needed()
|
||||
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
) -> Document:
|
||||
if scored_point.payload.get('doc_id'):
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata={'doc_id': scored_point.id}
|
||||
)
|
||||
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self.client, QdrantLocal):
|
||||
self.client = cast(QdrantLocal, self.client)
|
||||
self.client._load()
|
||||
|
@ -1,852 +0,0 @@
|
||||
"""Wrapper around the Milvus vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MILVUS_CONNECTION = {
|
||||
"host": "localhost",
|
||||
"port": "19530",
|
||||
"user": "",
|
||||
"password": "",
|
||||
"secure": False,
|
||||
}
|
||||
|
||||
|
||||
class Milvus(VectorStore):
|
||||
"""Initialize wrapper around the milvus vector database.
|
||||
|
||||
In order to use this you need to have `pymilvus` installed and a
|
||||
running Milvus
|
||||
|
||||
See the following documentation for how to run a Milvus instance:
|
||||
https://milvus.io/docs/install_standalone-docker.md
|
||||
|
||||
If looking for a hosted Milvus, take a look at this documentation:
|
||||
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
|
||||
this project,
|
||||
|
||||
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
|
||||
|
||||
Args:
|
||||
embedding_function (Embeddings): Function used to embed the text.
|
||||
collection_name (str): Which Milvus collection to use. Defaults to
|
||||
"LangChainCollection".
|
||||
connection_args (Optional[dict[str, any]]): The connection args used for
|
||||
this class comes in the form of a dict.
|
||||
consistency_level (str): The consistency level to use for a collection.
|
||||
Defaults to "Session".
|
||||
index_params (Optional[dict]): Which index params to use. Defaults to
|
||||
HNSW/AUTOINDEX depending on service.
|
||||
search_params (Optional[dict]): Which search params to use. Defaults to
|
||||
default of index.
|
||||
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
|
||||
to False.
|
||||
|
||||
The connection args used for this class comes in the form of a dict,
|
||||
here are a few of the options:
|
||||
address (str): The actual address of Milvus
|
||||
instance. Example address: "localhost:19530"
|
||||
uri (str): The uri of Milvus instance. Example uri:
|
||||
"http://randomwebsite:19530",
|
||||
"tcp:foobarsite:19530",
|
||||
"https://ok.s3.south.com:19530".
|
||||
host (str): The host of Milvus instance. Default at "localhost",
|
||||
PyMilvus will fill in the default host if only port is provided.
|
||||
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
|
||||
will fill in the default port if only host is provided.
|
||||
user (str): Use which user to connect to Milvus instance. If user and
|
||||
password are provided, we will add related header in every RPC call.
|
||||
password (str): Required when user is provided. The password
|
||||
corresponding to the user.
|
||||
secure (bool): Default is false. If set to true, tls will be enabled.
|
||||
client_key_path (str): If use tls two-way authentication, need to
|
||||
write the client.key path.
|
||||
client_pem_path (str): If use tls two-way authentication, need to
|
||||
write the client.pem path.
|
||||
ca_pem_path (str): If use tls two-way authentication, need to write
|
||||
the ca.pem path.
|
||||
server_pem_path (str): If use tls one-way authentication, need to
|
||||
write the server.pem path.
|
||||
server_name (str): If use tls, need to write the common name.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import Milvus
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embedding = OpenAIEmbeddings()
|
||||
# Connect to a milvus instance on localhost
|
||||
milvus_store = Milvus(
|
||||
embedding_function = Embeddings,
|
||||
collection_name = "LangChainCollection",
|
||||
drop_old = True,
|
||||
)
|
||||
|
||||
Raises:
|
||||
ValueError: If the pymilvus python package is not installed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Embeddings,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: Optional[dict[str, Any]] = None,
|
||||
consistency_level: str = "Session",
|
||||
index_params: Optional[dict] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
drop_old: Optional[bool] = False,
|
||||
):
|
||||
"""Initialize the Milvus vector store."""
|
||||
try:
|
||||
from pymilvus import Collection, utility
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
|
||||
# Default search params when one is not provided.
|
||||
self.default_search_params = {
|
||||
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
|
||||
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
|
||||
"AUTOINDEX": {"metric_type": "L2", "params": {}},
|
||||
}
|
||||
|
||||
self.embedding_func = embedding_function
|
||||
self.collection_name = collection_name
|
||||
self.index_params = index_params
|
||||
self.search_params = search_params
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
# In order for a collection to be compatible, pk needs to be auto'id and int
|
||||
self._primary_field = "id"
|
||||
# In order for compatibility, the text field will need to be called "text"
|
||||
self._text_field = "page_content"
|
||||
# In order for compatibility, the vector field needs to be called "vector"
|
||||
self._vector_field = "vectors"
|
||||
# In order for compatibility, the metadata field will need to be called "metadata"
|
||||
self._metadata_field = "metadata"
|
||||
self.fields: list[str] = []
|
||||
# Create the connection to the server
|
||||
if connection_args is None:
|
||||
connection_args = DEFAULT_MILVUS_CONNECTION
|
||||
self.alias = self._create_connection_alias(connection_args)
|
||||
self.col: Optional[Collection] = None
|
||||
|
||||
# Grab the existing collection if it exists
|
||||
if utility.has_collection(self.collection_name, using=self.alias):
|
||||
self.col = Collection(
|
||||
self.collection_name,
|
||||
using=self.alias,
|
||||
)
|
||||
# If need to drop old, drop it
|
||||
if drop_old and isinstance(self.col, Collection):
|
||||
self.col.drop()
|
||||
self.col = None
|
||||
|
||||
# Initialize the vector store
|
||||
self._init()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_func
|
||||
|
||||
def _create_connection_alias(self, connection_args: dict) -> str:
|
||||
"""Create the connection to the Milvus server."""
|
||||
from pymilvus import MilvusException, connections
|
||||
|
||||
# Grab the connection arguments that are used for checking existing connection
|
||||
host: str = connection_args.get("host", None)
|
||||
port: Union[str, int] = connection_args.get("port", None)
|
||||
address: str = connection_args.get("address", None)
|
||||
uri: str = connection_args.get("uri", None)
|
||||
user = connection_args.get("user", None)
|
||||
|
||||
# Order of use is host/port, uri, address
|
||||
if host is not None and port is not None:
|
||||
given_address = str(host) + ":" + str(port)
|
||||
elif uri is not None:
|
||||
given_address = uri.split("https://")[1]
|
||||
elif address is not None:
|
||||
given_address = address
|
||||
else:
|
||||
given_address = None
|
||||
logger.debug("Missing standard address type for reuse atttempt")
|
||||
|
||||
# User defaults to empty string when getting connection info
|
||||
if user is not None:
|
||||
tmp_user = user
|
||||
else:
|
||||
tmp_user = ""
|
||||
|
||||
# If a valid address was given, then check if a connection exists
|
||||
if given_address is not None:
|
||||
for con in connections.list_connections():
|
||||
addr = connections.get_connection_addr(con[0])
|
||||
if (
|
||||
con[1]
|
||||
and ("address" in addr)
|
||||
and (addr["address"] == given_address)
|
||||
and ("user" in addr)
|
||||
and (addr["user"] == tmp_user)
|
||||
):
|
||||
logger.debug("Using previous connection: %s", con[0])
|
||||
return con[0]
|
||||
|
||||
# Generate a new connection if one doesn't exist
|
||||
alias = uuid4().hex
|
||||
try:
|
||||
connections.connect(alias=alias, **connection_args)
|
||||
logger.debug("Created new connection using: %s", alias)
|
||||
return alias
|
||||
except MilvusException as e:
|
||||
logger.error("Failed to create new connection using: %s", alias)
|
||||
raise e
|
||||
|
||||
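# Illustrative sketch (hypothetical endpoints) of the accepted address forms; host/port
# is preferred, then uri (https only, given the split above), then address, when
# checking whether an existing connection alias can be reused.
#
#     Milvus(embedding_function=embedding, connection_args={"host": "127.0.0.1", "port": "19530"})
#     Milvus(embedding_function=embedding, connection_args={"uri": "https://milvus.example.com:19530"})
#     Milvus(embedding_function=embedding, connection_args={"address": "10.0.0.5:19530"})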
def _init(
|
||||
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
if embeddings is not None:
|
||||
self._create_collection(embeddings, metadatas)
|
||||
self._extract_fields()
|
||||
self._create_index()
|
||||
self._create_search_params()
|
||||
self._load()
|
||||
|
||||
def _create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
|
||||
# Determine embedding dim
|
||||
dim = len(embeddings[0])
|
||||
fields = []
|
||||
# Determine metadata schema
|
||||
# if metadatas:
|
||||
# # Create FieldSchema for each entry in metadata.
|
||||
# for key, value in metadatas[0].items():
|
||||
# # Infer the corresponding datatype of the metadata
|
||||
# dtype = infer_dtype_bydata(value)
|
||||
# # Datatype isn't compatible
|
||||
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
|
||||
# logger.error(
|
||||
# "Failure to create collection, unrecognized dtype for key: %s",
|
||||
# key,
|
||||
# )
|
||||
# raise ValueError(f"Unrecognized datatype for {key}.")
|
||||
# # Datatype is a string/varchar equivalent
|
||||
# elif dtype == DataType.VARCHAR:
|
||||
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
|
||||
# else:
|
||||
# fields.append(FieldSchema(key, dtype))
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(
|
||||
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
|
||||
)
|
||||
)
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(
|
||||
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||
)
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create the collection
|
||||
try:
|
||||
self.col = Collection(
|
||||
name=self.collection_name,
|
||||
schema=schema,
|
||||
consistency_level=self.consistency_level,
|
||||
using=self.alias,
|
||||
)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to create collection: %s error: %s", self.collection_name, e
|
||||
)
|
||||
raise e
|
||||
|
||||
def _extract_fields(self) -> None:
|
||||
"""Grab the existing fields from the Collection"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection):
|
||||
schema = self.col.schema
|
||||
for x in schema.fields:
|
||||
self.fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self.fields.remove(self._primary_field)
|
||||
|
||||
def _get_index(self) -> Optional[dict[str, Any]]:
|
||||
"""Return the vector index information if it exists"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection):
|
||||
for x in self.col.indexes:
|
||||
if x.field_name == self._vector_field:
|
||||
return x.to_dict()
|
||||
return None
|
||||
|
||||
def _create_index(self) -> None:
|
||||
"""Create a index on the collection"""
|
||||
from pymilvus import Collection, MilvusException
|
||||
|
||||
if isinstance(self.col, Collection) and self._get_index() is None:
|
||||
try:
|
||||
# If no index params, use a default HNSW based one
|
||||
if self.index_params is None:
|
||||
self.index_params = {
|
||||
"metric_type": "IP",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
|
||||
try:
|
||||
self.col.create_index(
|
||||
self._vector_field,
|
||||
index_params=self.index_params,
|
||||
using=self.alias,
|
||||
)
|
||||
|
||||
# If default did not work, most likely on Zilliz Cloud
|
||||
except MilvusException:
|
||||
# Use AUTOINDEX based index
|
||||
self.index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "AUTOINDEX",
|
||||
"params": {},
|
||||
}
|
||||
self.col.create_index(
|
||||
self._vector_field,
|
||||
index_params=self.index_params,
|
||||
using=self.alias,
|
||||
)
|
||||
logger.debug(
|
||||
"Successfully created an index on collection: %s",
|
||||
self.collection_name,
|
||||
)
|
||||
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to create an index on collection: %s", self.collection_name
|
||||
)
|
||||
raise e
|
||||
|
||||
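# Illustrative sketch (hypothetical parameter values): overriding the default HNSW
# index by passing custom index_params at construction time, e.g. an IVF_FLAT index.
#
#     store = Milvus(
#         embedding_function=embedding,
#         collection_name="LangChainCollection",
#         index_params={
#             "metric_type": "L2",
#             "index_type": "IVF_FLAT",
#             "params": {"nlist": 128},
#         },
#     )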
def _create_search_params(self) -> None:
|
||||
"""Generate search params based on the current index type"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection) and self.search_params is None:
|
||||
index = self._get_index()
|
||||
if index is not None:
|
||||
index_type: str = index["index_param"]["index_type"]
|
||||
metric_type: str = index["index_param"]["metric_type"]
|
||||
self.search_params = self.default_search_params[index_type]
|
||||
self.search_params["metric_type"] = metric_type
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load the collection if available."""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection) and self._get_index() is not None:
|
||||
self.col.load()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
batch_size: int = 1000,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Insert text data into Milvus.
|
||||
|
||||
Inserting data when the collection has not been made yet will result
|
||||
in creating a new Collection. The data of the first entity decides
|
||||
the schema of the new collection, the dim is extracted from the first
|
||||
embedding and the columns are decided by the first metadata dict.
|
||||
Metadata keys will need to be present for all inserted values. At
|
||||
the moment there is no None equivalent in Milvus.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): The texts to embed, it is assumed
|
||||
that they all fit in memory.
|
||||
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
|
||||
the texts. Defaults to None.
|
||||
timeout (Optional[int]): Timeout for each batch insert. Defaults
|
||||
to None.
|
||||
batch_size (int, optional): Batch size to use for insertion.
|
||||
Defaults to 1000.
|
||||
|
||||
Raises:
|
||||
MilvusException: Failure to add texts
|
||||
|
||||
Returns:
|
||||
List[str]: The resulting keys for each inserted element.
|
||||
"""
|
||||
from pymilvus import Collection, MilvusException
|
||||
|
||||
texts = list(texts)
|
||||
|
||||
try:
|
||||
embeddings = self.embedding_func.embed_documents(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||
|
||||
if len(embeddings) == 0:
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
|
||||
# If the collection hasn't been initialized yet, perform all steps to do so
|
||||
if not isinstance(self.col, Collection):
|
||||
self._init(embeddings, metadatas)
|
||||
|
||||
# Dict to hold all insert columns
|
||||
insert_dict: dict[str, list] = {
|
||||
self._text_field: texts,
|
||||
self._vector_field: embeddings,
|
||||
}
|
||||
|
||||
# Collect the metadata into the insert dict.
|
||||
# if metadatas is not None:
|
||||
# for d in metadatas:
|
||||
# for key, value in d.items():
|
||||
# if key in self.fields:
|
||||
# insert_dict.setdefault(key, []).append(value)
|
||||
if metadatas is not None:
|
||||
for d in metadatas:
|
||||
insert_dict.setdefault(self._metadata_field, []).append(d)
|
||||
|
||||
# Total insert count
|
||||
vectors: list = insert_dict[self._vector_field]
|
||||
total_count = len(vectors)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
assert isinstance(self.col, Collection)
|
||||
for i in range(0, total_count, batch_size):
|
||||
# Grab end index
|
||||
end = min(i + batch_size, total_count)
|
||||
# Convert dict to list of lists batch for insertion
|
||||
insert_list = [insert_dict[x][i:end] for x in self.fields]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
res: Collection
|
||||
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
|
||||
pks.extend(res.primary_keys)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||
)
|
||||
raise e
|
||||
return pks
|
||||
|
||||
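# Usage sketch (hypothetical texts and metadata): the first add_texts call creates the
# collection, taking the vector dimension from the first embedding and storing each
# metadata dict in the JSON "metadata" field; later calls reuse that schema.
#
#     ids = store.add_texts(
#         texts=["Milvus stores vectors", "Weaviate stores vectors too"],
#         metadatas=[{"doc_id": "a"}, {"doc_id": "b"}],
#         batch_size=1000,
#     )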
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Perform a similarity search against the query string.
|
||||
|
||||
Args:
|
||||
query (str): The text to search.
|
||||
k (int, optional): How many results to return. Defaults to 4.
|
||||
param (dict, optional): The search params for the index type.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
res = self.similarity_search_with_score(
|
||||
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Perform a similarity search against the query string.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector to search.
|
||||
k (int, optional): How many results to return. Defaults to 4.
|
||||
param (dict, optional): The search params for the index type.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
documentation found here:
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
query (str): The text being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: Result doc and score.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
# Embed the query text.
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return res
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
documentation found here:
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: Result doc and score.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
if param is None:
|
||||
param = self.search_params
|
||||
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self._vector_field)
|
||||
|
||||
# Perform the search.
|
||||
res = self.col.search(
|
||||
data=[embedding],
|
||||
anns_field=self._vector_field,
|
||||
param=param,
|
||||
limit=k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
# Organize results.
|
||||
ret = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
|
||||
pair = (doc, result.score)
|
||||
ret.append(pair)
|
||||
|
||||
return ret
|
||||
|
||||
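# Usage sketch (hypothetical query and filter): when `param` is omitted the search
# params derived from the collection's index are used; `expr` can narrow the candidate
# set, e.g. via the JSON "metadata" field on Milvus versions that support JSON
# filtering expressions.
#
#     docs_and_scores = store.similarity_search_with_score(
#         query="how are vectors indexed?",
#         k=4,
#         expr='metadata["document_id"] == "doc-1"',
#     )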
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
query (str): The text being searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5
|
||||
param (dict, optional): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
param=param,
|
||||
expr=expr,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5
|
||||
param (dict, optional): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
if param is None:
|
||||
param = self.search_params
|
||||
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self._vector_field)
|
||||
|
||||
# Perform the search.
|
||||
res = self.col.search(
|
||||
data=[embedding],
|
||||
anns_field=self._vector_field,
|
||||
param=param,
|
||||
limit=fetch_k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
# Organize results.
|
||||
ids = []
|
||||
documents = []
|
||||
scores = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
|
||||
documents.append(doc)
|
||||
scores.append(result.score)
|
||||
ids.append(result.id)
|
||||
|
||||
vectors = self.col.query(
|
||||
expr=f"{self._primary_field} in {ids}",
|
||||
output_fields=[self._primary_field, self._vector_field],
|
||||
timeout=timeout,
|
||||
)
|
||||
# Reorganize the results from query to match search order.
|
||||
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
|
||||
|
||||
ordered_result_embeddings = [vectors[x] for x in ids]
|
||||
|
||||
# Get the new order of results.
|
||||
new_ordering = maximal_marginal_relevance(
|
||||
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
|
||||
# Reorder the values and return.
|
||||
ret = []
|
||||
for x in new_ordering:
|
||||
# Function can return -1 index
|
||||
if x == -1:
|
||||
break
|
||||
else:
|
||||
ret.append(documents[x])
|
||||
return ret
|
||||
|
||||
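# Usage sketch (hypothetical query): fetch the 20 nearest candidates, then keep the 4
# that best balance relevance and diversity (lambda_mult=0.5 weighs both equally).
#
#     docs = store.max_marginal_relevance_search(
#         query="vector store comparison",
#         k=4,
#         fetch_k=20,
#         lambda_mult=0.5,
#     )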
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
|
||||
consistency_level: str = "Session",
|
||||
index_params: Optional[dict] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
drop_old: bool = False,
|
||||
batch_size: int = 100,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Milvus:
|
||||
"""Create a Milvus collection, indexes it with HNSW, and insert data.
|
||||
|
||||
Args:
|
||||
texts (List[str]): Text data.
|
||||
embedding (Embeddings): Embedding function.
|
||||
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
|
||||
Defaults to None.
|
||||
collection_name (str, optional): Collection name to use. Defaults to
|
||||
"LangChainCollection".
|
||||
connection_args (dict[str, Any], optional): Connection args to use. Defaults
|
||||
to DEFAULT_MILVUS_CONNECTION.
|
||||
consistency_level (str, optional): Which consistency level to use. Defaults
|
||||
to "Session".
|
||||
index_params (Optional[dict], optional): Which index_params to use. Defaults
|
||||
to None.
|
||||
search_params (Optional[dict], optional): Which search params to use.
|
||||
Defaults to None.
|
||||
drop_old (Optional[bool], optional): Whether to drop the collection with
|
||||
that name if it exists. Defaults to False.
|
||||
batch_size:
|
||||
How many vectors upload per-request.
|
||||
Default: 100
|
||||
ids (Optional[Sequence[str]], optional): Optional list of ids. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Milvus: Milvus Vector Store
|
||||
"""
|
||||
vector_db = cls(
|
||||
embedding_function=embedding,
|
||||
collection_name=collection_name,
|
||||
connection_args=connection_args,
|
||||
consistency_level=consistency_level,
|
||||
index_params=index_params,
|
||||
search_params=search_params,
|
||||
drop_old=drop_old,
|
||||
**kwargs,
|
||||
)
|
||||
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
|
||||
return vector_db
|
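# Usage sketch (hypothetical texts; assumes an OpenAI API key is configured for
# OpenAIEmbeddings): from_texts builds the collection, creates the index and inserts
# the data in one call.
#
#     from langchain.embeddings import OpenAIEmbeddings
#
#     store = Milvus.from_texts(
#         texts=["alpha", "beta"],
#         embedding=OpenAIEmbeddings(),
#         metadatas=[{"source": "a"}, {"source": "b"}],
#         collection_name="LangChainCollection",
#         drop_old=True,
#     )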
@ -1,506 +0,0 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
|
||||
def _default_schema(index_name: str) -> dict:
|
||||
return {
|
||||
"class": index_name,
|
||||
"properties": [
|
||||
{
|
||||
"name": "text",
|
||||
"dataType": ["text"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _create_weaviate_client(**kwargs: Any) -> Any:
|
||||
client = kwargs.get("client")
|
||||
if client is not None:
|
||||
return client
|
||||
|
||||
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
|
||||
|
||||
try:
|
||||
# the weaviate api key param should not be mandatory
|
||||
weaviate_api_key = get_from_dict_or_env(
|
||||
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
|
||||
)
|
||||
except ValueError:
|
||||
weaviate_api_key = None
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`"
|
||||
)
|
||||
|
||||
auth = (
|
||||
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
|
||||
if weaviate_api_key is not None
|
||||
else None
|
||||
)
|
||||
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def _default_score_normalizer(val: float) -> float:
|
||||
return 1 - val
|
||||
|
||||
|
||||
def _json_serializable(value: Any) -> Any:
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.isoformat()
|
||||
return value
|
||||
|
||||
|
||||
class Weaviate(VectorStore):
|
||||
"""Wrapper around Weaviate vector database.
|
||||
|
||||
To use, you should have the ``weaviate-client`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import weaviate
|
||||
from langchain.vectorstores import Weaviate
|
||||
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
|
||||
weaviate = Weaviate(client, index_name, text_key)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
index_name: str,
|
||||
text_key: str,
|
||||
embedding: Optional[Embeddings] = None,
|
||||
attributes: Optional[list[str]] = None,
|
||||
relevance_score_fn: Optional[
|
||||
Callable[[float], float]
|
||||
] = _default_score_normalizer,
|
||||
by_text: bool = True,
|
||||
):
|
||||
"""Initialize with Weaviate client."""
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`."
|
||||
)
|
||||
if not isinstance(client, weaviate.Client):
|
||||
raise ValueError(
|
||||
f"client should be an instance of weaviate.Client, got {type(client)}"
|
||||
)
|
||||
self._client = client
|
||||
self._index_name = index_name
|
||||
self._embedding = embedding
|
||||
self._text_key = text_key
|
||||
self._query_attrs = [self._text_key]
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
self._by_text = by_text
|
||||
if attributes is not None:
|
||||
self._query_attrs.extend(attributes)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return (
|
||||
self.relevance_score_fn
|
||||
if self.relevance_score_fn
|
||||
else _default_score_normalizer
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Upload texts with metadata (properties) to Weaviate."""
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
ids = []
|
||||
embeddings: Optional[list[list[float]]] = None
|
||||
if self._embedding:
|
||||
if not isinstance(texts, list):
|
||||
texts = list(texts)
|
||||
embeddings = self._embedding.embed_documents(texts)
|
||||
|
||||
with self._client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {self._text_key: text}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
data_properties[key] = _json_serializable(val)
|
||||
|
||||
# Allow for ids (consistent w/ other methods)
|
||||
# # Or uuids (backwards compatible w/ existing arg)
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing object will be replaced by the new object.
|
||||
_id = get_valid_uuid(uuid4())
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
elif "ids" in kwargs:
|
||||
_id = kwargs["ids"][i]
|
||||
|
||||
batch.add_data_object(
|
||||
data_object=data_properties,
|
||||
class_name=self._index_name,
|
||||
uuid=_id,
|
||||
vector=embeddings[i] if embeddings else None,
|
||||
)
|
||||
ids.append(_id)
|
||||
return ids
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
if self._by_text:
|
||||
return self.similarity_search_by_text(query, k, **kwargs)
|
||||
else:
|
||||
if self._embedding is None:
|
||||
raise ValueError(
|
||||
"_embedding cannot be None for similarity_search when "
|
||||
"_by_text=False"
|
||||
)
|
||||
embedding = self._embedding.embed_query(query)
|
||||
return self.similarity_search_by_vector(embedding, k, **kwargs)
|
||||
|
||||
def similarity_search_by_text(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
content: dict[str, Any] = {"concepts": [query]}
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
result = query_obj.with_near_text(content).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def similarity_search_by_bm25(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
"""Return docs using BM25F.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
content: dict[str, Any] = {"concepts": [query]}
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
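# Usage sketch (hypothetical query and property name): BM25F keyword search over the
# "text" property, optionally narrowed with a Weaviate where-filter.
#
#     docs = weaviate_store.similarity_search_by_bm25(
#         "keyword table storage",
#         k=4,
#         where_filter={
#             "path": ["source"],
#             "operator": "Equal",
#             "valueText": "doc-1",
#         },
#     )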
def similarity_search_by_vector(
|
||||
self, embedding: list[float], k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
vector = {"vector": embedding}
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
result = query_obj.with_near_vector(vector).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self._embedding is not None:
|
||||
embedding = self._embedding.embed_query(query)
|
||||
else:
|
||||
raise ValueError(
|
||||
"max_marginal_relevance_search requires a suitable Embeddings object"
|
||||
)
|
||||
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
vector = {"vector": embedding}
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
results = (
|
||||
query_obj.with_additional("vector")
|
||||
.with_near_vector(vector)
|
||||
.with_limit(fetch_k)
|
||||
.do()
|
||||
)
|
||||
|
||||
payload = results["data"]["Get"][self._index_name]
|
||||
embeddings = [result["_additional"]["vector"] for result in payload]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
|
||||
docs = []
|
||||
for idx in mmr_selected:
|
||||
text = payload[idx].pop(self._text_key)
|
||||
payload[idx].pop("_additional")
|
||||
meta = payload[idx]
|
||||
docs.append(Document(page_content=text, metadata=meta))
|
||||
return docs
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""
|
||||
Return list of documents most similar to the query
|
||||
text and cosine distance in float for each.
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
if self._embedding is None:
|
||||
raise ValueError(
|
||||
"_embedding cannot be None for similarity_search_with_score"
|
||||
)
|
||||
content: dict[str, Any] = {"concepts": [query]}
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
|
||||
embedded_query = self._embedding.embed_query(query)
|
||||
if not self._by_text:
|
||||
vector = {"vector": embedded_query}
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(k)
|
||||
.with_additional(["vector", "distance"])
|
||||
.do()
|
||||
)
|
||||
else:
|
||||
result = (
|
||||
query_obj.with_near_text(content)
|
||||
.with_limit(k)
|
||||
.with_additional(["vector", "distance"])
|
||||
.do()
|
||||
)
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
docs_and_scores = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
score = res["_additional"]["distance"]
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
return docs_and_scores
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: type[Weaviate],
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Weaviate:
|
||||
"""Construct Weaviate wrapper from raw documents.
|
||||
|
||||
This is a user-friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates a new index for the embeddings in the Weaviate instance.
|
||||
3. Adds the documents to the newly created Weaviate index.
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores.weaviate import Weaviate
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
weaviate = Weaviate.from_texts(
|
||||
texts,
|
||||
embeddings,
|
||||
weaviate_url="http://localhost:8080"
|
||||
)
|
||||
"""
|
||||
|
||||
client = _create_weaviate_client(**kwargs)
|
||||
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
|
||||
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||
text_key = "text"
|
||||
schema = _default_schema(index_name)
|
||||
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||
|
||||
# check whether the index already exists
|
||||
if not client.schema.contains(schema):
|
||||
client.schema.create_class(schema)
|
||||
|
||||
with client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {
|
||||
text_key: text,
|
||||
}
|
||||
if metadatas is not None:
|
||||
for key in metadatas[i].keys():
|
||||
data_properties[key] = metadatas[i][key]
|
||||
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing object will be replaced by the new object.
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
else:
|
||||
_id = get_valid_uuid(uuid4())
|
||||
|
||||
# if an embedding strategy is not provided, we let
|
||||
# weaviate create the embedding. Note that this will only
|
||||
# work if weaviate has been installed with a vectorizer module
|
||||
# like text2vec-contextionary for example
|
||||
params = {
|
||||
"uuid": _id,
|
||||
"data_object": data_properties,
|
||||
"class_name": index_name,
|
||||
}
|
||||
if embeddings is not None:
|
||||
params["vector"] = embeddings[i]
|
||||
|
||||
batch.add_data_object(**params)
|
||||
|
||||
batch.flush()
|
||||
|
||||
relevance_score_fn = kwargs.get("relevance_score_fn")
|
||||
by_text: bool = kwargs.get("by_text", False)
|
||||
|
||||
return cls(
|
||||
client,
|
||||
index_name,
|
||||
text_key,
|
||||
embedding=embedding,
|
||||
attributes=attributes,
|
||||
relevance_score_fn=relevance_score_fn,
|
||||
by_text=by_text,
|
||||
)
|
||||
|
||||
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
|
||||
"""Delete by vector IDs.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
"""
|
||||
|
||||
if ids is None:
|
||||
raise ValueError("No ids provided to delete.")
|
||||
|
||||
# TODO: Check if this can be done in bulk
|
||||
for id in ids:
|
||||
self._client.data_object.delete(uuid=id)
|
@ -1,38 +0,0 @@
|
||||
from core.vector_store.vector.weaviate import Weaviate
|
||||
|
||||
|
||||
class WeaviateVectorStore(Weaviate):
|
||||
def del_texts(self, where_filter: dict):
|
||||
if not where_filter:
|
||||
raise ValueError('where_filter must not be empty')
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._index_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
|
||||
def del_text(self, uuid: str) -> None:
|
||||
self._client.data_object.delete(
|
||||
uuid,
|
||||
class_name=self._index_name
|
||||
)
|
||||
|
||||
def text_exists(self, uuid: str) -> bool:
|
||||
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": uuid,
|
||||
}).with_limit(1).do()
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
entries = result["data"]["Get"][self._index_name]
|
||||
if len(entries) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def delete(self):
|
||||
self._client.schema.delete_class(self._index_name)
|
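# Usage sketch (hypothetical instance, metadata path and uuid): remove every object of
# one document with a where-filter, or a single object by its Weaviate UUID.
#
#     store.del_texts({
#         "path": ["document_id"],
#         "operator": "Equal",
#         "valueText": "doc-1",
#     })
#     store.del_text(segment_uuid)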
@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task
|
||||
def handle(sender, **kwargs):
|
||||
dataset = sender
|
||||
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
|
||||
dataset.index_struct, dataset.collection_binding_id)
|
||||
dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
|
||||
|
@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task
|
||||
def handle(sender, **kwargs):
|
||||
document_id = sender
|
||||
dataset_id = kwargs.get('dataset_id')
|
||||
clean_document_task.delay(document_id, dataset_id)
|
||||
doc_form = kwargs.get('doc_form')
|
||||
clean_document_task.delay(document_id, dataset_id, doc_form)
|
||||
|
@ -94,6 +94,14 @@ class Dataset(db.Model):
|
||||
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
|
||||
.filter(Document.dataset_id == self.id).scalar()
|
||||
|
||||
@property
|
||||
def doc_form(self):
|
||||
document = db.session.query(Document).filter(
|
||||
Document.dataset_id == self.id).first()
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
|
||||
@property
|
||||
def retrieval_model_dict(self):
|
||||
default_retrieval_model = {
|
||||
|
@ -6,7 +6,7 @@ from flask import current_app
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import app
|
||||
from core.index.index import IndexBuilder
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, Document
|
||||
|
||||
@ -41,18 +41,9 @@ def clean_unused_datasets_task():
|
||||
if not documents or len(documents) == 0:
|
||||
try:
|
||||
# remove index
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
vector_index.delete()
|
||||
kw_index.delete()
|
||||
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, None)
|
||||
|
||||
# update document
|
||||
update_params = {
|
||||
Document.enabled: False
|
||||
|
@ -11,10 +11,11 @@ from flask_login import current_user
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.models.document import Document as RAGDocument
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
from events.document_event import document_was_deleted
|
||||
from extensions.ext_database import db
|
||||
@ -402,7 +403,7 @@ class DocumentService:
|
||||
@staticmethod
|
||||
def delete_document(document):
|
||||
# trigger document_was_deleted signal
|
||||
document_was_deleted.send(document.id, dataset_id=document.dataset_id)
|
||||
document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form)
|
||||
|
||||
db.session.delete(document)
|
||||
db.session.commit()
|
||||
@ -1060,7 +1061,7 @@ class SegmentService:
|
||||
|
||||
# save vector index
|
||||
try:
|
||||
VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
|
||||
VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
|
||||
except Exception as e:
|
||||
logging.exception("create segment index failed")
|
||||
segment_document.enabled = False
|
||||
@ -1087,6 +1088,7 @@ class SegmentService:
|
||||
).scalar()
|
||||
pre_segment_data_list = []
|
||||
segment_data_list = []
|
||||
keywords_list = []
|
||||
for segment_item in segments:
|
||||
content = segment_item['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
@ -1119,15 +1121,13 @@ class SegmentService:
|
||||
segment_document.answer = segment_item['answer']
|
||||
db.session.add(segment_document)
|
||||
segment_data_list.append(segment_document)
|
||||
pre_segment_data = {
|
||||
'segment': segment_document,
|
||||
'keywords': segment_item['keywords']
|
||||
}
|
||||
pre_segment_data_list.append(pre_segment_data)
|
||||
|
||||
pre_segment_data_list.append(segment_document)
|
||||
keywords_list.append(segment_item['keywords'])
|
||||
|
||||
try:
|
||||
# save vector index
|
||||
VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
|
||||
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("create segment index failed")
|
||||
for segment_document in segment_data_list:
|
||||
@ -1157,11 +1157,18 @@ class SegmentService:
|
||||
db.session.commit()
|
||||
# update segment index task
|
||||
if args['keywords']:
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from keyword index
|
||||
kw_index.delete_by_ids([segment.index_node_id])
|
||||
# save keyword index
|
||||
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
|
||||
keyword = Keyword(dataset)
|
||||
keyword.delete_by_ids([segment.index_node_id])
|
||||
document = RAGDocument(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
keyword.add_texts([document], keywords_list=[args['keywords']])
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
|
@ -9,8 +9,8 @@ from flask_login import current_user
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Account
|
||||
@ -32,7 +32,8 @@ class FileService:
|
||||
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
|
||||
extension = file.filename.split('.')[-1]
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
||||
allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
|
||||
else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
|
||||
if extension.lower() not in allowed_extensions:
|
||||
raise UnsupportedFileTypeError()
|
||||
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
|
||||
@ -136,7 +137,7 @@ class FileService:
|
||||
if extension.lower() not in allowed_extensions:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = FileExtractor.load(upload_file, return_text=True)
|
||||
text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
|
||||
return text
|
||||
@ -164,7 +165,7 @@ class FileService:
|
||||
return generator, upload_file.mime_type
|
||||
|
||||
@staticmethod
|
||||
def get_public_image_preview(file_id: str) -> str:
|
||||
def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
|
@ -1,21 +1,18 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from flask import current_app
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
@ -28,6 +25,7 @@ default_retrieval_model = {
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
|
||||
@ -57,61 +55,15 @@ class HitTestingService:
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
all_documents = []
|
||||
threads = []
|
||||
|
||||
# retrieval_model source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
top_n=retrieval_model['top_k'],
|
||||
user=f"account-{account.id}"
|
||||
)
|
||||
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model['top_k'],
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
@ -203,4 +155,3 @@ class HitTestingService:
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError('Query is required and cannot exceed 250 characters')
|
||||
|
||||
|
@ -1,119 +0,0 @@
from typing import Optional

from flask import Flask, current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from models.dataset import Dataset

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class RetrievalService:

    @classmethod
    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                         all_documents: list, search_method: str, embeddings: Embeddings):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': top_k,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                }
            )

            if documents:
                if reranking_model and search_method == 'semantic_search':
                    try:
                        model_manager = ModelManager()
                        rerank_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=reranking_model['reranking_provider_name'],
                            model_type=ModelType.RERANK,
                            model=reranking_model['reranking_model_name']
                        )
                    except InvokeAuthorizationError:
                        return

                    rerank_runner = RerankRunner(rerank_model_instance)
                    all_documents.extend(rerank_runner.run(
                        query=query,
                        documents=documents,
                        score_threshold=score_threshold,
                        top_n=len(documents)
                    ))
                else:
                    all_documents.extend(documents)

    @classmethod
    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                               all_documents: list, search_method: str, embeddings: Embeddings):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search_by_full_text_index(
                query,
                search_type='similarity_score_threshold',
                top_k=top_k
            )
            if documents:
                if reranking_model and search_method == 'full_text_search':
                    try:
                        model_manager = ModelManager()
                        rerank_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=reranking_model['reranking_provider_name'],
                            model_type=ModelType.RERANK,
                            model=reranking_model['reranking_model_name']
                        )
                    except InvokeAuthorizationError:
                        return

                    rerank_runner = RerankRunner(rerank_model_instance)
                    all_documents.extend(rerank_runner.run(
                        query=query,
                        documents=documents,
                        score_threshold=score_threshold,
                        top_n=len(documents)
                    ))
                else:
                    all_documents.extend(documents)
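The deleted module above ran embedding_search and full_text_index_search on background threads, re-entering the Flask application context inside each worker before touching the database. A minimal, self-contained sketch of that pattern; the worker body, dataset id, and query below are placeholders, not code from this repository.

    import threading

    from flask import Flask

    app = Flask(__name__)
    all_documents: list = []

    def run_search(flask_app: Flask, dataset_id: str, query: str, all_documents: list):
        # Placeholder worker mirroring the shape of the deleted embedding_search / full_text_index_search.
        with flask_app.app_context():
            # ...load the dataset, query the index, optionally rerank...
            all_documents.extend([])  # results would be appended to the shared list here

    thread = threading.Thread(target=run_search, kwargs={
        'flask_app': app,              # the service passed current_app._get_current_object()
        'dataset_id': 'dataset-uuid',  # placeholder
        'query': 'example query',      # placeholder
        'all_documents': all_documents,
    })
    thread.start()
    thread.join()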
@ -1,44 +1,18 @@

from typing import Optional

from langchain.schema import Document

from core.index.index import IndexBuilder
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from models.dataset import Dataset, DocumentSegment


class VectorService:

    @classmethod
    def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts([document], duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            if keywords and len(keywords) > 0:
                index.create_segment_keywords(segment.index_node_id, keywords)
            else:
                index.add_texts([document])

    @classmethod
    def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
    def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
                               segments: list[DocumentSegment], dataset: Dataset):
        documents = []
        for pre_segment_data in pre_segment_data_list:
            segment = pre_segment_data['segment']
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
@ -49,30 +23,26 @@ class VectorService:
                }
            )
            documents.append(document)

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)
        if dataset.indexing_technique == 'high_quality':
            # save vector index
            vector = Vector(
                dataset=dataset
            )
            vector.add_texts(documents, duplicate_check=True)

        # save keyword index
        keyword_index = IndexBuilder.get_index(dataset, 'economy')
        if keyword_index:
            keyword_index.multi_create_segment_keywords(pre_segment_data_list)
        keyword = Keyword(dataset)

        if keywords_list and len(keywords_list) > 0:
            keyword.add_texts(documents, keyword_list=keywords_list)
        else:
            keyword.add_texts(documents)

    @classmethod
    def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
        # update segment index task
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')
        # delete from vector index
        if vector_index:
            vector_index.delete_by_ids([segment.index_node_id])

        # delete from keyword index
        kw_index.delete_by_ids([segment.index_node_id])

        # add new index
        # format new index
        document = Document(
            page_content=segment.content,
            metadata={
@ -82,13 +52,20 @@ class VectorService:
                "dataset_id": segment.dataset_id,
            }
        )
        if dataset.indexing_technique == 'high_quality':
            # update vector index
            vector = Vector(
                dataset=dataset
            )
            vector.delete_by_ids([segment.index_node_id])
            vector.add_texts([document], duplicate_check=True)

        # save vector index
        if vector_index:
            vector_index.add_texts([document], duplicate_check=True)
        # update keyword index
        keyword = Keyword(dataset)
        keyword.delete_by_ids([segment.index_node_id])

        # save keyword index
        if keywords and len(keywords) > 0:
            kw_index.create_segment_keywords(segment.index_node_id, keywords)
            keyword.add_texts([document], keywords_list=[keywords])
        else:
            kw_index.add_texts([document])
            keyword.add_texts([document])
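The hunks above swap the IndexBuilder-based indexes for the new Vector and Keyword factories. A condensed sketch of the segment-update path follows; the helper name is hypothetical, the body mirrors update_segment_vector above, and only the factory calls shown in the diff are used.

    from typing import Optional

    from core.rag.datasource.keyword.keyword_factory import Keyword
    from core.rag.datasource.vdb.vector_factory import Vector
    from core.rag.models.document import Document
    from models.dataset import Dataset, DocumentSegment

    def reindex_segment(segment: DocumentSegment, dataset: Dataset, keywords: Optional[list[str]] = None) -> None:
        # Hypothetical helper; body mirrors VectorService.update_segment_vector above.
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        if dataset.indexing_technique == 'high_quality':
            vector = Vector(dataset=dataset)
            vector.delete_by_ids([segment.index_node_id])       # drop the stale embedding
            vector.add_texts([document], duplicate_check=True)  # re-embed the updated content

        keyword = Keyword(dataset)
        keyword.delete_by_ids([segment.index_node_id])
        if keywords:
            keyword.add_texts([document], keywords_list=[keywords])
        else:
            keyword.add_texts([document])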
@ -4,10 +4,10 @@ import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Document as DatasetDocument
@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str):
        if not dataset:
            raise Exception('Document has no dataset')

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

        end_at = time.perf_counter()
        logging.info(
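With this change the task no longer writes the vector and keyword indexes separately; it delegates to an index processor chosen by the dataset's doc_form. A minimal sketch, assuming `dataset` and `documents` are loaded exactly as in the task above:

    from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

    index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
    index_processor.load(dataset, documents)  # replaces the separate vector/keyword index writes above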
|