fix: dataset desc (#1045)

This commit is contained in:
takatost 2023-08-29 09:07:27 +08:00 committed by GitHub
parent a55ba6e614
commit 7b3314c5db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 17 additions and 18 deletions

View File

@ -52,7 +52,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:

View File

@ -45,7 +45,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
self.llm.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()

View File

@ -90,7 +90,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

View File

@ -1,5 +1,6 @@
import json
import logging
from json import JSONDecodeError
from typing import Any, Dict, List, Union, Optional
@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
input_str: str,
**kwargs: Any,
) -> None:
# tool_name = serialized.get('name')
input_dict = json.loads(input_str.replace("'", "\""))
dataset_id = input_dict.get('dataset_id')
query = input_dict.get('query')
tool_name: str = serialized.get('name')
dataset_id = tool_name.removeprefix('dataset-')
try:
input_dict = json.loads(input_str.replace("'", "\""))
query = input_dict.get('query')
except JSONDecodeError:
query = input_str
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(

View File

@ -1,4 +1,3 @@
import re
from typing import Type
from flask import current_app
@ -16,7 +15,6 @@ from models.dataset import Dataset, DocumentSegment
class DatasetRetrieverToolInput(BaseModel):
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
@ -37,27 +35,22 @@ class DatasetRetrieverTool(BaseTool):
description = 'useful for when you want to answer queries about the ' + dataset.name
description = description.replace('\n', '').replace('\r', '')
description += '\nID of dataset MUST be ' + dataset.id
return cls(
name=f'dataset-{dataset.id}',
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs
)
def _run(self, dataset_id: str, query: str) -> str:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, dataset_id, re.IGNORECASE)
if match:
dataset_id = match.group()
def _run(self, query: str) -> str:
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
Dataset.id == self.dataset_id
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
if dataset.indexing_technique == "economy":
# use keyword table query