dify/api/core/agent/base_agent_runner.py

520 lines
20 KiB
Python
Raw Normal View History

import json
2024-02-01 18:11:57 +08:00
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
2024-02-01 18:11:57 +08:00
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
2024-01-30 15:25:37 +08:00
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
2024-02-01 18:11:57 +08:00
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
)
2024-02-01 18:11:57 +08:00
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
2024-02-01 18:11:57 +08:00
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(
    self,
    tenant_id: str,
    application_generate_entity: AgentChatAppGenerateEntity,
    conversation: Conversation,
    app_config: AgentChatAppConfig,
    model_config: ModelConfigWithCredentialsEntity,
    config: AgentEntity,
    queue_manager: AppQueueManager,
    message: Message,
    user_id: str,
    memory: Optional[TokenBufferMemory] = None,
    prompt_messages: Optional[list[PromptMessage]] = None,
    variables_pool: Optional[ToolRuntimeVariablePool] = None,
    db_variables: Optional[ToolConversationVariables] = None,
    model_instance: Optional[ModelInstance] = None,
) -> None:
    """
    Set up the shared state of an agent run.

    Stores the given arguments on the instance, rebuilds the prompt history
    through ``organize_agent_history``, creates the agent and dataset-index
    callback handlers, builds the dataset retriever tools, counts the agent
    thoughts already stored for ``message`` and records which optional
    features (streamed tool calls, vision) the model schema advertises.

    :param prompt_messages: initial prompt messages; ``None`` is treated as
        an empty list
    :param db_variables: stored on the instance as ``db_variables_pool``
    :param model_instance: model used to look up the model schema; it has a
        ``None`` default only because it follows optional parameters, and it
        must always be supplied
    :raises ValueError: if ``model_instance`` is ``None``
    """
    # The model instance is dereferenced further down, so reject a missing one
    # up front instead of failing with an AttributeError after the database
    # has already been queried.
    if model_instance is None:
        raise ValueError("model_instance is required")

    self.tenant_id = tenant_id
    self.application_generate_entity = application_generate_entity
    self.conversation = conversation
    self.app_config = app_config
    self.model_config = model_config
    self.config = config
    self.queue_manager = queue_manager
    self.message = message
    self.user_id = user_id
    self.memory = memory
    # organize_agent_history reads self.message and self.tenant_id, which is
    # why it runs only after the assignments above.
    self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
    self.variables_pool = variables_pool
    self.db_variables_pool = db_variables
    self.model_instance = model_instance

    # init callback
    self.agent_callback = DifyAgentCallbackHandler()

    # init dataset tools
    hit_callback = DatasetIndexToolCallbackHandler(
        queue_manager=queue_manager,
        app_id=self.app_config.app_id,
        message_id=message.id,
        user_id=user_id,
        invoke_from=self.application_generate_entity.invoke_from,
    )
    dataset_config = app_config.dataset
    self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
        tenant_id=tenant_id,
        dataset_ids=dataset_config.dataset_ids if dataset_config else [],
        retrieve_config=dataset_config.retrieve_config if dataset_config else None,
        return_resource=app_config.additional_features.show_retrieve_source,
        invoke_from=application_generate_entity.invoke_from,
        hit_callback=hit_callback,
    )

    # get how many agent thoughts have been created
    self.agent_thought_count = (
        db.session.query(MessageAgentThought)
        .filter(
            MessageAgentThought.message_id == self.message.id,
        )
        .count()
    )
    db.session.close()

    # look up the model schema once and derive both feature flags from it
    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    features = (model_schema.features or []) if model_schema else []

    # check if model supports stream tool call
    self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features

    # uploaded files are only kept when the model supports vision
    self.files = application_generate_entity.files if ModelFeature.VISION in features else []

    self.query = None
    self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
    self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
    """
    Normalise the app generate entity in place and return it.

    A missing (``None``) simple prompt template is replaced by an empty string;
    the same entity object that was passed in is returned.
    """
    prompt_template = app_generate_entity.app_config.prompt_template
    if prompt_template.simple_prompt_template is None:
        prompt_template.simple_prompt_template = ""

    return app_generate_entity
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
    """
    Convert an agent tool configuration into a prompt message tool.

    Resolves the tool runtime through ``ToolManager``, loads the variables pool
    into it and describes every LLM-form parameter as a JSON-schema style
    property on the returned prompt tool.

    :param tool: agent tool configuration to convert
    :return: ``(prompt message tool, resolved tool runtime)``
    """
    tool_entity = ToolManager.get_agent_tool_runtime(
        tenant_id=self.tenant_id,
        app_id=self.app_config.app_id,
        agent_tool=tool,
        invoke_from=self.application_generate_entity.invoke_from,
    )
    tool_entity.load_variables(self.variables_pool)

    message_tool = PromptMessageTool(
        name=tool.tool_name,
        description=tool_entity.description.llm,
        parameters={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )
    properties = message_tool.parameters["properties"]
    required = message_tool.parameters["required"]

    for parameter in tool_entity.get_all_runtime_parameters():
        # only parameters the LLM is expected to fill in are exposed
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue

        schema = {
            "type": parameter.type.as_normal_type(),
            "description": parameter.llm_description or "",
        }

        if parameter.type == ToolParameter.ToolParameterType.SELECT:
            # `or []` guards against a select parameter declared without any
            # options (NOTE(review): assumes `options` may be None — confirm
            # against ToolParameter).
            enum = [option.value for option in parameter.options or []]
            if enum:
                schema["enum"] = enum

        properties[parameter.name] = schema

        # same duplicate guard as the other prompt-tool converters in this class
        if parameter.required and parameter.name not in required:
            required.append(parameter.name)

    return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
    """
    Describe a dataset retriever tool as a prompt message tool.

    Every runtime parameter of the tool is exposed as a ``"string"`` property;
    required parameters are listed once in ``required``.
    """
    prompt_tool = PromptMessageTool(
        name=tool.identity.name,
        description=tool.description.llm,
        parameters={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )
    properties = prompt_tool.parameters["properties"]
    required_names = prompt_tool.parameters["required"]

    for runtime_parameter in tool.get_runtime_parameters():
        param_name = runtime_parameter.name
        properties[param_name] = {
            "type": "string",
            "description": runtime_parameter.llm_description or "",
        }

        if runtime_parameter.required and param_name not in required_names:
            required_names.append(param_name)

    return prompt_tool
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
    """
    Build the tool runtimes and their prompt descriptions for this run.

    Agent tools from the app config come first, followed by the dataset
    retriever tools. An agent tool whose conversion fails is skipped (and
    logged) so that one broken tool does not abort the whole run.

    :return: ``(tool instances keyed by tool name, prompt message tools)``
    """
    tool_instances = {}
    prompt_messages_tools = []

    agent_tools = self.app_config.agent.tools if self.app_config.agent else []
    for tool in agent_tools:
        try:
            prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
        except Exception:
            # api tool may be deleted; keep going, but leave a trace of which
            # tool was dropped and why instead of swallowing the error
            logger.warning("Skipping agent tool %s: failed to initialise", tool.tool_name, exc_info=True)
            continue

        # save tool entity
        tool_instances[tool.tool_name] = tool_entity
        # save prompt tool
        prompt_messages_tools.append(prompt_tool)

    # convert dataset tools into ModelRuntime Tool format
    for dataset_tool in self.dataset_tools:
        prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
        # save prompt tool
        prompt_messages_tools.append(prompt_tool)
        # save tool entity
        tool_instances[dataset_tool.identity.name] = dataset_tool

    return tool_instances, prompt_messages_tools
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
    """
    Refresh a prompt message tool from the tool's current runtime parameters.

    Every LLM-form runtime parameter is (re)written into
    ``prompt_tool.parameters["properties"]``; required parameters are added to
    ``required`` once. ``prompt_tool`` is modified in place and returned.

    :param tool: tool whose runtime parameters are read
    :param prompt_tool: prompt message tool to update
    :return: the same ``prompt_tool`` object
    """
    properties = prompt_tool.parameters["properties"]
    required = prompt_tool.parameters["required"]

    # the tool may report no runtime parameters at all
    for parameter in tool.get_runtime_parameters() or []:
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue

        schema = {
            "type": parameter.type.as_normal_type(),
            "description": parameter.llm_description or "",
        }

        if parameter.type == ToolParameter.ToolParameterType.SELECT:
            # `or []` guards against a select parameter declared without any
            # options (NOTE(review): assumes `options` may be None — confirm
            # against ToolParameter).
            enum = [option.value for option in parameter.options or []]
            if enum:
                schema["enum"] = enum

        properties[parameter.name] = schema

        if parameter.required and parameter.name not in required:
            required.append(parameter.name)

    return prompt_tool
def create_agent_thought(
    self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
    """
    Persist a new, mostly empty agent thought and return it.

    The record is positioned after the thoughts counted so far; token, price
    and latency fields start at zero and are filled in later. The in-memory
    thought counter is advanced once the record has been committed.
    """
    next_position = self.agent_thought_count + 1
    # file ids are stored as a JSON array, or as an empty string when there are none
    serialized_file_ids = json.dumps(messages_ids) if messages_ids else ""

    thought = MessageAgentThought(
        message_id=message_id,
        message_chain_id=None,
        thought="",
        tool=tool_name,
        tool_labels_str="{}",
        tool_meta_str="{}",
        tool_input=tool_input,
        message=message,
        message_token=0,
        message_unit_price=0,
        message_price_unit=0,
        message_files=serialized_file_ids,
        answer="",
        observation="",
        answer_token=0,
        answer_unit_price=0,
        answer_price_unit=0,
        tokens=0,
        total_price=0,
        position=next_position,
        currency="USD",
        latency=0,
        created_by_role="account",
        created_by=self.user_id,
    )

    session = db.session
    session.add(thought)
    session.commit()
    # reload the row before the session is closed so the returned object
    # carries the values written by the database
    session.refresh(thought)
    session.close()

    self.agent_thought_count = next_position

    return thought
def save_agent_thought(
    self,
    agent_thought: MessageAgentThought,
    tool_name: str,
    tool_input: Union[str, dict],
    thought: str,
    observation: Union[str, dict],
    tool_invoke_meta: Union[str, dict],
    answer: str,
    messages_ids: list[str],
    llm_usage: Optional[LLMUsage] = None,
) -> None:
    """
    Update a stored agent thought with the results of one agent step.

    The record is re-read from the database by ``agent_thought.id``. Every
    argument that is ``None`` leaves the corresponding column untouched; dict
    values for ``tool_input``, ``observation`` and ``tool_invoke_meta`` are
    stored as JSON text. Tool labels are completed for every tool named in the
    record's ``tool`` column (``;``-separated).

    Nothing is returned; the update is committed and the session is closed,
    also when an error occurs.

    :param agent_thought: previously created thought; only its ``id`` is used
    :param messages_ids: message file ids, stored as a JSON array when non-empty
    :param llm_usage: token and price figures to record, if available
    :raises ValueError: if no agent thought with that id exists
    """

    def _as_json_text(value):
        # dicts become JSON text (non-ASCII kept readable when possible);
        # anything else is stored as given
        if not isinstance(value, dict):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return json.dumps(value)

    try:
        record = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if record is None:
            raise ValueError("agent thought not found")

        if thought is not None:
            record.thought = thought

        if tool_name is not None:
            record.tool = tool_name

        if tool_input is not None:
            record.tool_input = _as_json_text(tool_input)

        if observation is not None:
            record.observation = _as_json_text(observation)

        if answer is not None:
            record.answer = answer

        if messages_ids:
            record.message_files = json.dumps(messages_ids)

        if llm_usage:
            record.message_token = llm_usage.prompt_tokens
            record.message_price_unit = llm_usage.prompt_price_unit
            record.message_unit_price = llm_usage.prompt_unit_price
            record.answer_token = llm_usage.completion_tokens
            record.answer_price_unit = llm_usage.completion_price_unit
            record.answer_unit_price = llm_usage.completion_unit_price
            record.tokens = llm_usage.total_tokens
            record.total_price = llm_usage.total_price

        # make sure every tool named on the record has a label entry
        labels = record.tool_labels or {}
        tool_names = record.tool.split(";") if record.tool else []
        for name in tool_names:
            if not name or name in labels:
                continue
            tool_label = ToolManager.get_tool_label(name)
            if tool_label:
                labels[name] = tool_label.to_dict()
            else:
                # no label known: fall back to the raw tool name
                labels[name] = {"en_US": name, "zh_Hans": name}

        record.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            record.tool_meta_str = _as_json_text(tool_invoke_meta)

        db.session.commit()
    finally:
        db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
    """
    Persist the tool variable pool for the current conversation.

    The stored variables row is looked up by the conversation id of the current
    message; its ``variables_str`` is replaced by the JSON-encoded pool and its
    ``updated_at`` is set to the current UTC time (stored without tzinfo). The
    session is closed afterwards, also when an error occurs.

    :param tool_variables: variable pool whose ``pool`` is serialised
    :param db_variables: not used for the update; the row is always re-read
        from the database
    :raises ValueError: if the conversation has no stored variables row
    """
    try:
        record = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )
        if record is None:
            raise ValueError("db variables not found")

        record.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        record.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
    finally:
        db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    """
    Rebuild the prompt history of the current conversation thread.

    The result starts with the system messages found in ``prompt_messages``.
    Then, for every earlier message of the thread (the current message is
    skipped), it contains the user prompt followed by either

    - one assistant message per stored agent thought, with the recorded tool
      calls and one tool message per call, or
    - the plain answer, when the message has no agent thoughts.

    Tool call ids are not stored, so a fresh uuid is generated for every
    replayed call.

    :param prompt_messages: prompt messages to take the system messages from
    :return: prompt messages in chronological order
    """

    def _load_mapping(raw, fallback):
        # Stored text is only usable when it decodes to a JSON object; valid
        # JSON of another shape (list, number, string, null) would break the
        # per-tool `.get` lookups below, so it is treated like unparsable text.
        try:
            parsed = json.loads(raw)
        except Exception:
            return fallback
        return parsed if isinstance(parsed, dict) else fallback

    def _as_content(value):
        # Structured tool responses are replayed as JSON text.
        # NOTE(review): assumes message content has to be text — confirm
        # against ToolPromptMessage.
        if value is None or isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    # keep the system messages of the incoming prompt
    result = [
        prompt_message for prompt_message in prompt_messages if isinstance(prompt_message, SystemPromptMessage)
    ]

    messages: list[Message] = (
        db.session.query(Message)
        .filter(
            Message.conversation_id == self.message.conversation_id,
        )
        .order_by(Message.created_at.desc())
        .all()
    )
    # the query returns newest first; the history is built oldest first
    messages = list(reversed(extract_thread_messages(messages)))

    for message in messages:
        if message.id == self.message.id:
            continue

        result.append(self.organize_agent_user_prompt(message))

        agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
        if not agent_thoughts:
            if message.answer:
                result.append(AssistantPromptMessage(content=message.answer))
            continue

        for agent_thought in agent_thoughts:
            tool_field = agent_thought.tool
            if not tool_field:
                # a thought without tool calls is replayed as plain assistant text
                result.append(AssistantPromptMessage(content=agent_thought.thought))
                continue

            tools = tool_field.split(";")
            tool_inputs = _load_mapping(agent_thought.tool_input, {tool: {} for tool in tools})
            tool_responses = _load_mapping(
                agent_thought.observation, dict.fromkeys(tools, agent_thought.observation)
            )

            tool_calls: list[AssistantPromptMessage.ToolCall] = []
            tool_call_response: list[ToolPromptMessage] = []
            for tool in tools:
                # generate a uuid for tool call
                tool_call_id = str(uuid.uuid4())
                tool_calls.append(
                    AssistantPromptMessage.ToolCall(
                        id=tool_call_id,
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool,
                            arguments=json.dumps(tool_inputs.get(tool, {})),
                        ),
                    )
                )
                tool_call_response.append(
                    ToolPromptMessage(
                        content=_as_content(tool_responses.get(tool, agent_thought.observation)),
                        name=tool,
                        tool_call_id=tool_call_id,
                    )
                )

            result.extend(
                [
                    AssistantPromptMessage(
                        content=agent_thought.thought,
                        tool_calls=tool_calls,
                    ),
                    *tool_call_response,
                ]
            )

    db.session.close()

    return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
    """
    Build the user prompt message for a stored message.

    The prompt is the plain query text unless the message has files that can be
    turned into file objects under the message's file upload configuration; in
    that case the content is the query text followed by one entry per file.
    """
    message_files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
    if not message_files:
        return UserPromptMessage(content=message.query)

    upload_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
    if upload_config:
        file_objs = file_factory.build_from_message_files(
            message_files=message_files, tenant_id=self.tenant_id, config=upload_config
        )
    else:
        file_objs = []

    if not file_objs:
        return UserPromptMessage(content=message.query)

    # text first, then the files in the order they were built
    contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
    contents.extend(file_manager.to_prompt_message_content(file_obj) for file_obj in file_objs)

    return UserPromptMessage(content=contents)