dify/api/core/agent/cot_agent_runner.py


import json
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from models.model import Message


class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ['wenxin']
    _historic_prompt_messages: Optional[list[PromptMessage]] = None
    _agent_scratchpad: Optional[list[AgentScratchpadUnit]] = None
    _instruction: Optional[str] = None
    _query: Optional[str] = None
    _prompt_messages_tools: Optional[list[PromptMessage]] = None
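
    # This runner drives a ReAct-style loop: each iteration the model emits a
    # Thought plus either an Action (a tool call) or a Final Answer; tool
    # output is appended to the scratchpad as an Observation and fed back into
    # the next prompt until the model answers or the iteration budget runs out.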

    def run(self, message: Message,
            query: str,
            inputs: dict[str, str],
            ) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
2024-04-11 18:34:17 +08:00
self._init_react_state(query)
# check model mode
if 'Observation' not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append('Observation')
app_config = self.app_config

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
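        # the configured budget is capped at 5; the extra step is the forced
        # final round in which all tools are removed so the model must answer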

        # convert tools into ModelRuntime Tool format
        tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

        prompt_messages = self._organize_prompt_messages()

        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
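
        # note: increase_usage mutates final_llm_usage_dict in place, so the
        # totals are visible to the final chunk and the end event below
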
        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue running until there are no more tool calls
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration: remove all tools so the model must answer
                self._prompt_messages_tools = []

            message_file_ids = []

            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=[],
                stop=app_generate_entity.model_conf.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not chunks:
                raise ValueError("failed to invoke llm")

            usage_dict = {}
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response='',
                thought='',
                action_str='',
                observation='',
                action=None,
            )
2024-01-24 12:09:30 +08:00
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
2024-01-24 12:09:30 +08:00
2024-02-21 10:45:59 +08:00
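
            # the parser yields plain text deltas (thought tokens) and, once a
            # complete action block is detected in the stream, a parsed
            # AgentScratchpadUnit.Action; it also fills usage_dict at the end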
            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):
                    # detect action
                    action = chunk
                    scratchpad.agent_response += json.dumps(chunk.model_dump())
                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:
                    # stream thought tokens back to the client as they arrive
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint='',
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=chunk
                            ),
                            usage=None
                        )
                    )

            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
            self._agent_scratchpad.append(scratchpad)

            # get llm usage (populated by the output parser once the stream ends)
            if 'usage' in usage_dict:
                increase_usage(llm_usage, usage_dict['usage'])
            else:
                usage_dict['usage'] = LLMUsage.empty_usage()

            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=scratchpad.action.action_name if scratchpad.action else '',
                tool_input={
                    scratchpad.action.action_name: scratchpad.action.action_input
                } if scratchpad.action else {},
                tool_invoke_meta={},
                thought=scratchpad.thought,
                observation='',
                answer=scratchpad.agent_response,
                messages_ids=[],
                llm_usage=usage_dict['usage']
            )
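            # this interim save records the thought before any tool runs; the
            # same agent_thought is saved again with the observation after the
            # tool call returns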

            if not scratchpad.is_final():
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ''
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f'{scratchpad.action.action_input}'
                    except json.JSONDecodeError:
                        final_answer = f'{scratchpad.action.action_input}'
                else:
                    function_call_state = True
                    # action is tool call, invoke tool
                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                        action=scratchpad.action,
                        tool_instances=tool_instances,
                        message_file_ids=message_file_ids
                    )
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response

                    self.save_agent_thought(
                        agent_thought=agent_thought,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought,
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict['usage']
                    )

                    self.queue_manager.publish(QueueAgentThoughtEvent(
                        agent_thought_id=agent_thought.id
                    ), PublishFrom.APPLICATION_MANAGER)
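
            # with function_call_state True the loop re-enters; the new
            # observation reaches the model via _organize_prompt_messages()
            # on the next pass
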
            # update prompt tool message
            for prompt_tool in self._prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(
                    content=final_answer
                ),
                usage=llm_usage['usage']
            ),
            system_fingerprint=''
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name='',
            tool_input={},
            tool_invoke_meta={},
            thought=final_answer,
            observation={},
            answer=final_answer,
            messages_ids=[]
        )

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)

    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
                              tool_instances: dict[str, Tool],
                              message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action
        :param action: action
        :param tool_instances: tool instances
        :param message_file_ids: message file ids
        :return: observation, meta
        """
        # action is tool call, invoke tool
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)

        if not tool_instance:
            answer = f"there is no tool named {tool_call_name}"
            return answer, ToolInvokeMeta.error_instance(answer)

        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass
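
        # the model may emit action_input as a JSON string rather than an
        # object (e.g. '{"query": "weather"}'); decode it best-effort and fall
        # back to passing the raw string through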

        # invoke tool
        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
            tool=tool_instance,
            tool_parameters=tool_call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback
        )

        # publish files
        for message_file, save_as in message_files:
            if save_as:
                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

            # publish message file
            self.queue_manager.publish(QueueMessageFileEvent(
                message_file_id=message_file.id
            ), PublishFrom.APPLICATION_MANAGER)
            # add message file ids
            message_file_ids.append(message_file.id)

        return tool_invoke_response, tool_invoke_meta

    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
        """
        convert dict to action
        """
        return AgentScratchpadUnit.Action(
            action_name=action['action'],
            action_input=action['action_input']
        )
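
    # _convert_dict_to_action expects a dict with 'action' and 'action_input'
    # keys, e.g. {'action': 'current_weather', 'action_input': {'city': 'Kyoto'}}
    # ('current_weather' is only an illustrative tool name)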

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
        """
        fill in inputs from external data tools, replacing `{{key}}`
        placeholders in the instruction with the corresponding input values
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
            except Exception:
                continue

        return instruction

    def _init_react_state(self, query: str) -> None:
        """
        init agent scratchpad
        """
        self._query = query
        self._agent_scratchpad = []
        self._historic_prompt_messages = self._organize_historic_prompt_messages()

    @abstractmethod
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        organize prompt messages
        """

    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        format assistant message
        """
        message = ''
        for scratchpad in agent_scratchpad:
            if scratchpad.is_final():
                message += f"Final Answer: {scratchpad.agent_response}"
            else:
                message += f"Thought: {scratchpad.thought}\n\n"
                if scratchpad.action_str:
                    message += f"Action: {scratchpad.action_str}\n\n"
                if scratchpad.observation:
                    message += f"Observation: {scratchpad.observation}\n\n"

        return message
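
    # a one-step scratchpad renders roughly as:
    #   Thought: <thought>\n\nAction: <action json>\n\nObservation: <tool output>\n\n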

    def _organize_historic_prompt_messages(
        self, current_session_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        result: list[PromptMessage] = []
        scratchpads: list[AgentScratchpadUnit] = []
        current_scratchpad: Optional[AgentScratchpadUnit] = None

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=self.history_prompt_messages,
            memory=self.memory
        ).get_prompt()

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                if not current_scratchpad:
                    current_scratchpad = AgentScratchpadUnit(
                        agent_response=message.content,
                        thought=message.content or 'I am thinking about how to help you',
                        action_str='',
                        action=None,
                        observation=None,
                    )
                    scratchpads.append(current_scratchpad)

                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments)
                        )
                        current_scratchpad.action_str = json.dumps(
                            current_scratchpad.action.to_dict()
                        )
                    except Exception:
                        pass
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    current_scratchpad.observation = message.content
            elif isinstance(message, UserPromptMessage):
                # a user turn closes the current scratchpad run: flush it as a
                # single formatted assistant message before the user message
                if scratchpads:
                    result.append(AssistantPromptMessage(
                        content=self._format_assistant_message(scratchpads)
                    ))
                    scratchpads = []
                    current_scratchpad = None

                result.append(message)

        if scratchpads:
            result.append(AssistantPromptMessage(
                content=self._format_assistant_message(scratchpads)
            ))

        return result