mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 03:07:59 +08:00
feat: add tool labels (#2178)
This commit is contained in:
parent
0940084fd2
commit
7cb75cb2e7
@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_labels': fields.Raw,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'observation': fields.String,
|
||||
|
@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
@ -281,7 +282,7 @@ class GenerateTaskPipeline:
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
agent_thought = (
|
||||
agent_thought: MessageAgentThought = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||
.first()
|
||||
@ -298,6 +299,7 @@ class GenerateTaskPipeline:
|
||||
'thought': agent_thought.thought,
|
||||
'observation': agent_thought.observation,
|
||||
'tool': agent_thought.tool,
|
||||
'tool_labels': agent_thought.tool_labels,
|
||||
'tool_input': agent_thought.tool_input,
|
||||
'created_at': int(self._message.created_at.timestamp()),
|
||||
'message_files': agent_thought.files
|
||||
|
@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
message_chain_id=None,
|
||||
thought='',
|
||||
tool=tool_name,
|
||||
tool_labels_str='{}',
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(';') if agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
if tool not in labels:
|
||||
tool_label = ToolManager.get_tool_label(tool)
|
||||
if tool_label:
|
||||
labels[tool] = tool_label.to_dict()
|
||||
else:
|
||||
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
||||
|
@ -31,6 +31,7 @@ import mimetypes
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_builtin_providers = {}
|
||||
_builtin_tools_labels = {}
|
||||
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
@ -233,7 +234,7 @@ class ToolManager:
|
||||
if len(_builtin_providers) > 0:
|
||||
return list(_builtin_providers.values())
|
||||
|
||||
builtin_providers = []
|
||||
builtin_providers: List[BuiltinToolProviderController] = []
|
||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider.startswith('__'):
|
||||
continue
|
||||
@ -264,8 +265,30 @@ class ToolManager:
|
||||
# cache the builtin providers
|
||||
for provider in builtin_providers:
|
||||
_builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
|
||||
return builtin_providers
|
||||
|
||||
@staticmethod
|
||||
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
|
||||
"""
|
||||
get the tool label
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
|
||||
:return: the label of the tool
|
||||
"""
|
||||
global _builtin_tools_labels
|
||||
if len(_builtin_tools_labels) == 0:
|
||||
# init the builtin providers
|
||||
ToolManager.list_builtin_providers()
|
||||
|
||||
if tool_name not in _builtin_tools_labels:
|
||||
return None
|
||||
|
||||
return _builtin_tools_labels[tool_name]
|
||||
|
||||
@staticmethod
|
||||
def user_list_providers(
|
||||
user_id: str,
|
||||
|
@ -49,10 +49,11 @@ agent_thought_fields = {
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_labels': fields.Raw,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'observation': fields.String,
|
||||
'files': fields.List(fields.String)
|
||||
'files': fields.List(fields.String),
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
|
@ -36,6 +36,7 @@ agent_thought_fields = {
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_labels': fields.Raw,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'observation': fields.String,
|
||||
|
@ -0,0 +1,32 @@
|
||||
"""add tool labels to agent thought
|
||||
|
||||
Revision ID: 380c6aa5a70d
|
||||
Revises: dfb3b7f477da
|
||||
Create Date: 2024-01-24 10:58:15.644445
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '380c6aa5a70d'
|
||||
down_revision = 'dfb3b7f477da'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.drop_column('tool_labels_str')
|
||||
|
||||
# ### end Alembic commands ###
|
@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model):
|
||||
position = db.Column(db.Integer, nullable=False)
|
||||
thought = db.Column(db.Text, nullable=True)
|
||||
tool = db.Column(db.Text, nullable=True)
|
||||
tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_input = db.Column(db.Text, nullable=True)
|
||||
observation = db.Column(db.Text, nullable=True)
|
||||
# plugin_id = db.Column(UUID, nullable=True) ## for future design
|
||||
@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model):
|
||||
return json.loads(self.message_files)
|
||||
else:
|
||||
return []
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> dict:
|
||||
try:
|
||||
if self.tool_labels_str:
|
||||
return json.loads(self.tool_labels_str)
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
return {}
|
||||
|
||||
class DatasetRetrieverResource(db.Model):
|
||||
__tablename__ = 'dataset_retriever_resources'
|
||||
|
Loading…
Reference in New Issue
Block a user