# Mirror of https://gitee.com/dify_ai/dify.git
# Synced 2024-12-02 19:27:48 +08:00
# 220 lines, 10 KiB, Python
import json
|
|
import threading
|
|
from typing import Type, Optional, List
|
|
|
|
from flask import current_app
|
|
from langchain.tools import BaseTool
|
|
from pydantic import Field, BaseModel
|
|
|
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
|
from core.conversation_message_task import ConversationMessageTask
|
|
from core.embedding.cached_embedding import CacheEmbedding
|
|
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
|
from core.index.vector_index.vector_index import VectorIndex
|
|
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
|
from core.model_providers.model_factory import ModelFactory
|
|
from extensions.ext_database import db
|
|
from models.dataset import Dataset, DocumentSegment, Document
|
|
from services.retrieval_service import RetrievalService
|
|
|
|
# Fallback retrieval settings, used when a dataset carries no retrieval model
# of its own: semantic search only, reranking and score threshold switched off.
default_retrieval_model = dict(
    search_method='semantic_search',
    reranking_enable=False,
    reranking_model=dict(
        reranking_provider_name='',
        reranking_model_name='',
    ),
    top_k=2,
    score_threshold_enable=False,
)
|
|
|
|
|
|
# Argument schema for the dataset retriever tool: one required free-text query.
# Documented with comments instead of a docstring so the schema pydantic
# generates for this model stays exactly as it was.
class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(
        ...,
        description="Query for the dataset to be used to retrieve the dataset.",
    )
|
|
|
|
|
|
class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""
    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    # Tenant and dataset this tool is bound to; both are used to load the
    # Dataset row at the start of every _run call.
    tenant_id: str
    dataset_id: str
    # Maximum number of hits requested from the index. For non-economy
    # datasets a value <= 0 skips the search entirely.
    top_k: int = 2
    # NOTE(review): not read anywhere in this class; the threshold actually
    # applied comes from the dataset's retrieval model. Confirm whether any
    # caller still relies on this field.
    score_threshold: Optional[float] = None
    # Passed to DatasetIndexToolCallbackHandler to report hits and resources.
    conversation_message_task: ConversationMessageTask
    # When true, per-segment source information is reported via the callback.
    return_resource: bool
    # Origin tag copied into each reported resource; the value 'dev' adds
    # extra segment statistics to the resource.
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        """Create a tool bound to ``dataset``.

        The tool is named ``dataset-<id>``. Its description is the dataset's
        own description, or a generic sentence built from the dataset name
        when the description is empty, with CR/LF characters removed.

        :param dataset: dataset the tool will query.
        :param kwargs: remaining tool fields, forwarded to the constructor.
        :return: a new instance of ``cls``.
        """
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        """Retrieve content relevant to ``query`` from the bound dataset.

        :param query: free-text query.
        :return: the matching contents joined with newlines. An empty string
            is returned when the dataset does not exist, when the embedding
            model cannot be obtained, or when nothing matches.
        """
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return ''

        # get retrieval model, falling back to the default when none is set
        retrieval_model = dataset.retrieval_model or default_retrieval_model

        if dataset.indexing_technique == "economy":
            # Economy datasets are served from the keyword table and return
            # the raw page contents without any hit/resource reporting.
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
            return "\n".join(document.page_content for document in documents)

        try:
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except (LLMBadRequestError, ProviderTokenNotInitError):
            # Best effort: an unusable embedding provider yields no context
            # rather than failing the whole tool call.
            return ''

        embeddings = CacheEmbedding(embedding_model)
        documents = self._search_documents(dataset, retrieval_model, embeddings, query)

        hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
        hit_callback.on_tool_end(documents)

        if not documents:
            # Nothing retrieved: skip the segment lookup, which would only run
            # an empty IN () query and produce the same empty result.
            return ''

        # A hit whose metadata carries no 'score' is reported with score None
        # instead of aborting the whole call with a KeyError.
        document_score_list = {
            item.metadata['doc_id']: item.metadata.get('score')
            for item in documents
        }
        index_node_ids = [document.metadata['doc_id'] for document in documents]
        segments = DocumentSegment.query.filter(
            DocumentSegment.dataset_id == self.dataset_id,
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,  # SQLAlchemy expression, not a Python comparison
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()

        if not segments:
            return ''

        # The query returns segments in arbitrary order; restore the order in
        # which the search produced them.
        position_by_node_id = {node_id: position for position, node_id in enumerate(index_node_ids)}
        sorted_segments = sorted(
            segments,
            key=lambda segment: position_by_node_id.get(segment.index_node_id, float('inf'))
        )

        document_context_list = [
            f'question:{segment.content} answer:{segment.answer}' if segment.answer else segment.content
            for segment in sorted_segments
        ]

        if self.return_resource:
            hit_callback.return_retriever_resource_info(
                self._build_retriever_resources(dataset, sorted_segments, document_score_list)
            )

        return "\n".join(document_context_list)

    def _search_documents(self, dataset, retrieval_model, embeddings, query: str) -> list:
        """Run the searches selected by ``retrieval_model`` and return the hits.

        One worker thread is started per applicable search (semantic and/or
        full text) and all of them are joined before returning. For
        ``hybrid_search`` the combined hits are reranked afterwards.

        :param dataset: dataset being queried.
        :param retrieval_model: retrieval settings mapping (search method,
            score threshold and reranking options).
        :param embeddings: embeddings object handed to the search workers.
        :param query: free-text query.
        :return: list of retrieved documents; empty when ``top_k`` <= 0 or the
            search method matches none of the known methods.
        """
        # Shared result list handed to every worker as 'all_documents';
        # presumably the workers append their hits to it -- verify against
        # RetrievalService.
        documents = []
        if self.top_k <= 0:
            return documents

        search_method = retrieval_model['search_method']
        search_targets = []
        # retrieval source with semantic
        if search_method in ('semantic_search', 'hybrid_search'):
            search_targets.append(RetrievalService.embedding_search)
        # retrieval source with full text
        if search_method in ('full_text_search', 'hybrid_search'):
            search_targets.append(RetrievalService.full_text_index_search)
        if not search_targets:
            return documents

        # Both workers take the same options, so resolve them once.
        score_threshold = retrieval_model['score_threshold'] \
            if retrieval_model['score_threshold_enable'] else None
        reranking_model = retrieval_model['reranking_model'] \
            if retrieval_model['reranking_enable'] else None
        # The real application object (not the context-local proxy) is passed
        # because the workers run on other threads.
        flask_app = current_app._get_current_object()

        threads = []
        for search_target in search_targets:
            thread = threading.Thread(target=search_target, kwargs={
                'flask_app': flask_app,
                'dataset_id': str(dataset.id),
                'query': query,
                'top_k': self.top_k,
                'score_threshold': score_threshold,
                'reranking_model': reranking_model,
                'all_documents': documents,
                'search_method': search_method,
                'embeddings': embeddings
            })
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # hybrid search: rerank after all documents have been searched
        if search_method == 'hybrid_search':
            hybrid_rerank = ModelFactory.get_reranking_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
                model_name=retrieval_model['reranking_model']['reranking_model_name']
            )
            documents = hybrid_rerank.rerank(query, documents, score_threshold, self.top_k)

        return documents

    def _build_retriever_resources(self, dataset, sorted_segments, document_score_list) -> List[dict]:
        """Build the source-information dicts reported for the retrieved segments.

        Segments whose document is missing, disabled or archived are skipped.

        :param dataset: dataset the segments belong to.
        :param sorted_segments: segments in retrieval order.
        :param document_score_list: mapping of index node id to score.
        :return: one dict per reported segment.
        """
        context_list = []
        # NOTE(review): 'position' counts every segment, including skipped
        # ones, i.e. the counter advances once per loop iteration. Confirm
        # that this matches the intended numbering.
        for resource_number, segment in enumerate(sorted_segments, start=1):
            document = Document.query.filter(
                Document.id == segment.document_id,
                Document.enabled == True,  # SQLAlchemy expression, not a Python comparison
                Document.archived == False,
            ).first()
            if not document:
                continue

            source = {
                'position': resource_number,
                'dataset_id': dataset.id,
                'dataset_name': dataset.name,
                'document_id': document.id,
                'document_name': document.name,
                'data_source_type': document.data_source_type,
                'segment_id': segment.id,
                'retriever_from': self.retriever_from,
                'score': document_score_list.get(segment.index_node_id, None)
            }
            if self.retriever_from == 'dev':
                source['hit_count'] = segment.hit_count
                source['word_count'] = segment.word_count
                source['segment_position'] = segment.position
                source['index_node_hash'] = segment.index_node_hash
            if segment.answer:
                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
            else:
                source['content'] = segment.content
            context_list.append(source)

        return context_list

    async def _arun(self, tool_input: str) -> str:
        """Asynchronous execution is not supported by this tool."""
        raise NotImplementedError()
|