From 6c148b223d9036d6d2a4128ffe12ce7b83759a78 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 26 Aug 2023 17:35:17 +0800 Subject: [PATCH] fix: dataset query truncated (#1026) --- api/core/agent/agent/multi_dataset_router_agent.py | 8 +++++++- api/core/agent/agent/openai_function_call.py | 7 +++++++ .../agent/agent/structed_multi_dataset_router_agent.py | 8 +++++++- api/core/agent/agent/structured_chat.py | 8 +++++++- 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index a2c0b4998..958183f08 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -60,7 +60,13 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): return AgentFinish(return_values={"output": observation}, log=observation) try: - return super().plan(intermediate_steps, callbacks, **kwargs) + agent_decision = super().plan(intermediate_steps, callbacks, **kwargs) + if isinstance(agent_decision, AgentAction): + tool_inputs = agent_decision.tool_input + if isinstance(tool_inputs, dict) and 'query' in tool_inputs: + tool_inputs['query'] = kwargs['input'] + agent_decision.tool_input = tool_inputs + return agent_decision except Exception as e: new_exception = self.model_instance.handle_exceptions(e) raise new_exception diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index a15e0bd8e..359078607 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -97,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio messages, functions=self.functions, callbacks=callbacks ) agent_decision = _parse_ai_message(predicted_message) + + if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': + tool_inputs = agent_decision.tool_input + if isinstance(tool_inputs, dict) and 'query' in tool_inputs: + tool_inputs['query'] = kwargs['input'] + agent_decision.tool_input = tool_inputs + return agent_decision @classmethod diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index 6e2198970..67522418d 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -102,7 +102,13 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): raise new_exception try: - return self.output_parser.parse(full_output) + agent_decision = self.output_parser.parse(full_output) + if isinstance(agent_decision, AgentAction): + tool_inputs = agent_decision.tool_input + if isinstance(tool_inputs, dict) and 'query' in tool_inputs: + tool_inputs['query'] = kwargs['input'] + agent_decision.tool_input = tool_inputs + return agent_decision except OutputParserException: return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " "I don't know how to respond to that."}, "") diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index 9d4f4d608..77635273e 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -106,7 +106,13 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): raise new_exception try: - return self.output_parser.parse(full_output) + agent_decision = self.output_parser.parse(full_output) + if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': + tool_inputs = agent_decision.tool_input + if isinstance(tool_inputs, dict) and 'query' in tool_inputs: + tool_inputs['query'] = kwargs['input'] + agent_decision.tool_input = tool_inputs + return agent_decision except OutputParserException: return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " "I don't know how to respond to that."}, "")