Fix/agent external knowledge retrieval (#9241)

This commit is contained in:
Jyong 2024-10-11 19:21:03 +08:00 committed by GitHub
parent 44f6a536d2
commit 42b02b3a5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 225 additions and 92 deletions

View File

@ -191,6 +191,22 @@ class CeleryConfig(DatabaseConfig):
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
class InternalTestConfig(BaseSettings):
"""
Configuration settings for Internal Test
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="Internal test AWS secret access key",
default=None,
)
AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="Internal test AWS access key ID",
default=None,
)
class MiddlewareConfig( class MiddlewareConfig(
# place the configs in alphabet order # place the configs in alphabet order
CeleryConfig, CeleryConfig,
@ -224,5 +240,6 @@ class MiddlewareConfig(
TiDBVectorConfig, TiDBVectorConfig,
WeaviateConfig, WeaviateConfig,
ElasticsearchConfig, ElasticsearchConfig,
InternalTestConfig,
): ):
pass pass

View File

@ -13,6 +13,7 @@ from libs.login import login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name): def _validate_name(name):
@ -232,8 +233,31 @@ class ExternalKnowledgeHitTestingApi(Resource):
raise InternalServerError(str(e)) raise InternalServerError(str(e))
class BedrockRetrievalApi(Resource):
# this api is only for internal testing
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args()
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing") api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external") api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>") api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check") api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
# this api is only for internal test
api.add_resource(BedrockRetrievalApi, "/test/retrieval")

View File

@ -539,7 +539,7 @@ class DatasetRetrieval:
continue continue
# pass if dataset is not available # pass if dataset is not available
if dataset and dataset.available_document_count == 0: if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
continue continue
available_datasets.append(dataset) available_datasets.append(dataset)

View File

@ -1,10 +1,12 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = { default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -53,97 +55,137 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id) hit_callback.on_query(query, dataset.id)
if dataset.provider == "external":
# get retrieval model , if the model is not setting , using default results = []
retrieval_model = dataset.retrieval_model or default_retrieval_model external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
if dataset.indexing_technique == "economy": tenant_id=dataset.tenant_id,
# use keyword table query dataset_id=dataset.id,
documents = RetrievalService.retrieve( query=query,
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k external_retrieval_parameters=dataset.retrieval_model,
) )
return str("\n".join([document.page_content for document in documents])) for external_document in external_documents:
else: document = RetrievalDocument(
if self.top_k > 0: page_content=external_document.get("content"),
# retrieval source metadata=external_document.get("metadata"),
documents = RetrievalService.retrieve( provider="external",
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
) )
else: document.metadata["score"] = external_document.get("score")
documents = [] document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents) hit_callback.return_retriever_resource_info(context_list)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments: return str("\n".join([item.page_content for item in results]))
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} else:
sorted_segments = sorted( # get retrieval model , if the model is not setting , using default
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
) )
for segment in sorted_segments: return str("\n".join([document.page_content for document in documents]))
if segment.answer: else:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") if self.top_k > 0:
else: # retrieval source
document_context_list.append(segment.get_sign_content()) documents = RetrievalService.retrieve(
if self.return_resource: retrieval_method=retrieval_model.get("search_method", "semantic_search"),
context_list = [] dataset_id=dataset.id,
resource_number = 1 query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments: for segment in sorted_segments:
context = {} if segment.answer:
document = Document.query.filter( document_context_list.append(
Document.id == segment.document_id, f"question:{segment.get_sign_content()} answer:{segment.answer}"
Document.enabled == True, )
Document.archived == False, else:
).first() document_context_list.append(segment.get_sign_content())
if dataset and document: if self.return_resource:
source = { context_list = []
"position": resource_number, resource_number = 1
"dataset_id": dataset.id, for segment in sorted_segments:
"dataset_name": dataset.name, context = {}
"document_id": document.id, document = Document.query.filter(
"document_name": document.name, Document.id == segment.document_id,
"data_source_type": document.data_source_type, Document.enabled == True,
"segment_id": segment.id, Document.archived == False,
"retriever_from": self.retriever_from, ).first()
"score": document_score_list.get(segment.index_node_id, None), if dataset and document:
} source = {
if self.retriever_from == "dev": "position": resource_number,
source["hit_count"] = segment.hit_count "dataset_id": dataset.id,
source["word_count"] = segment.word_count "dataset_name": dataset.name,
source["segment_position"] = segment.position "document_id": document.id,
source["index_node_hash"] = segment.index_node_hash "document_name": document.name,
if segment.answer: "data_source_type": document.data_source_type,
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" "segment_id": segment.id,
else: "retriever_from": self.retriever_from,
source["content"] = segment.content "score": document_score_list.get(segment.index_node_id, None),
context_list.append(source) }
resource_number += 1 if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list) hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list)) return str("\n".join(document_context_list))

View File

@ -79,8 +79,9 @@ class KnowledgeRetrievalNode(BaseNode):
results = ( results = (
db.session.query(Dataset) db.session.query(Dataset)
.join(subquery, Dataset.id == subquery.c.dataset_id) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all() .all()
) )
@ -121,10 +122,13 @@ class KnowledgeRetrievalNode(BaseNode):
) )
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
reranking_model = { if node_data.multiple_retrieval_config.reranking_model:
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, reranking_model = {
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
} "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
reranking_model = None reranking_model = None

View File

@ -234,6 +234,7 @@ class DatasetService:
dataset.name = data.get("name", dataset.name) dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", "") dataset.description = data.get("description", "")
external_knowledge_id = data.get("external_knowledge_id", None) external_knowledge_id = data.get("external_knowledge_id", None)
dataset.permission = data.get("permission")
db.session.add(dataset) db.session.add(dataset)
if not external_knowledge_id: if not external_knowledge_id:
raise ValueError("External knowledge id is required.") raise ValueError("External knowledge id is required.")

View File

@ -0,0 +1,45 @@
import boto3
from configs import dify_config
class ExternalDatasetTestService:
# this service is only for internal testing
@staticmethod
def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
# get bedrock client
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
# example: us-east-1
region_name="us-east-1",
)
# fetch external knowledge retrieval
response = client.retrieve(
knowledgeBaseId=knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": retrieval_setting.get("top_k"),
"overrideSearchType": "HYBRID",
}
},
retrievalQuery={"text": query},
)
# parse response
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
# filter out results with score less than threshold
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
continue
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return {"records": results}